use arc_swap::ArcSwap;
use std::sync::{Arc, Mutex};
use crate::config::models::{JokowayConfig, ServiceProtocol};
use jokoway_rules::parse_rule;
use jokoway_transformer::{
RequestTransformer, ResponseTransformer, parse_response_transformers, parse_transformers,
};
/// Protocols in the HTTP family without the `s` suffix: `Http`, `Ws` and `Grpc`.
///
/// Intended as the `allowed_protocols` argument of
/// `ServiceManager::get_indices_for_protocols`.
pub const HTTP_PROTOCOLS: [ServiceProtocol; 3] =
    [ServiceProtocol::Http, ServiceProtocol::Ws, ServiceProtocol::Grpc];

/// The `s`-suffixed counterparts of [`HTTP_PROTOCOLS`]: `Https`, `Wss` and `Grpcs`.
pub const HTTPS_PROTOCOLS: [ServiceProtocol; 3] =
    [ServiceProtocol::Https, ServiceProtocol::Wss, ServiceProtocol::Grpcs];

/// Every protocol variant, listed as (plain, `s`-suffixed) pairs.
pub const ALL_PROTOCOLS: [ServiceProtocol; 6] = [
    ServiceProtocol::Http, ServiceProtocol::Https,
    ServiceProtocol::Ws, ServiceProtocol::Wss,
    ServiceProtocol::Grpc, ServiceProtocol::Grpcs,
];
use jokoway_rules::Matcher;
/// A single compiled route belonging to a [`RuntimeService`].
pub struct RuntimeRoute {
    /// Matcher produced by parsing the route's rule string.
    pub matcher: Box<dyn Matcher>,
    /// Route priority taken from config (0 when unset); higher values sort first.
    pub priority: i32,
    /// Resolved retry budget: the route's own value, else the service's, else 1.
    pub max_retries: u32,
    /// Request transformer, present only if configured and parsed successfully.
    pub req_transformer: Option<Arc<dyn RequestTransformer>>,
    /// Response transformer, present only if configured and parsed successfully.
    pub res_transformer: Option<Arc<dyn ResponseTransformer>>,
}
/// A service together with its compiled, priority-ordered routes.
pub struct RuntimeService {
    /// Service name; used as the lookup key by `ServiceManager`.
    pub name: String,
    /// Host copied from the service configuration.
    pub host: Arc<str>,
    /// Protocols this service accepts; an empty list matches every protocol filter.
    pub protocols: Vec<ServiceProtocol>,
    /// Compiled routes, highest priority first.
    pub routes: Vec<RuntimeRoute>,
    /// The configuration this service was compiled from.
    pub config: Arc<crate::config::models::Service>,
}
/// Holds the set of compiled services and lets them be changed at runtime.
///
/// Readers load the current snapshot without locking. Each mutation publishes a
/// brand-new `Vec` through the `ArcSwap`.
pub struct ServiceManager {
    /// Current snapshot of compiled services.
    services: ArcSwap<Vec<RuntimeService>>,
    /// Subscribers invoked (via `notify_callbacks`) after services are added,
    /// updated or removed.
    callbacks: Mutex<Vec<Box<dyn Fn() + Send + Sync>>>,
}
use crate::error::JokowayError;
/// Compiles one service configuration into its runtime representation.
///
/// For each route this does three things:
/// - parses the rule into a matcher;
/// - parses the optional request/response transformers;
/// - resolves the retry budget: the route's `max_retries`, else the service's
///   `max_retries`, else 1.
///
/// Routes are then sorted by descending priority. The sort is stable, so routes
/// with equal priority keep their configuration order.
///
/// Side effect: every host reported by the compiled matchers is passed to
/// `jokoway_rules::registry::register_hosts`.
///
/// # Errors
///
/// Returns `JokowayError::Config` if any route rule fails to parse. Transformer
/// parse failures are *not* errors: they are logged and the route is compiled
/// without that transformer.
fn compile_service(
    service: &Arc<crate::config::models::Service>,
) -> Result<RuntimeService, JokowayError> {
    use jokoway_rules::registry::register_hosts;

    // Routes without their own retry budget inherit the service's, which in turn
    // defaults to 1.
    let service_max_retries = service.max_retries.unwrap_or(1);
    let mut routes = Vec::with_capacity(service.routes.len());

    for route_config in &service.routes {
        let matcher = parse_rule(&route_config.rule).map_err(|e| {
            JokowayError::Config(format!(
                "Rule parse error [service={}, route={}, rule={}]: {}",
                service.name, route_config.name, route_config.rule, e
            ))
        })?;

        // A transformer that fails to parse is logged and dropped rather than
        // failing the whole service.
        let req_transformer = route_config
            .request_transformer
            .as_ref()
            .and_then(|t_str| match parse_transformers(t_str) {
                Ok(t) => Some(Arc::from(t)),
                Err(e) => {
                    log::error!(
                        "Request transformer parse error [service={}, route={}, transformer={}]: {}",
                        service.name,
                        route_config.name,
                        t_str,
                        e
                    );
                    None
                }
            });
        let res_transformer = route_config
            .response_transformer
            .as_ref()
            .and_then(|t_str| match parse_response_transformers(t_str) {
                Ok(t) => Some(Arc::from(t)),
                Err(e) => {
                    log::error!(
                        "Response transformer parse error [service={}, route={}, transformer={}]: {}",
                        service.name,
                        route_config.name,
                        t_str,
                        e
                    );
                    None
                }
            });

        routes.push(RuntimeRoute {
            matcher,
            priority: route_config.priority.unwrap_or(0),
            max_retries: route_config.max_retries.unwrap_or(service_max_retries),
            req_transformer,
            res_transformer,
        });
    }

    // Highest priority first. `sort_by_key` is stable, so ties keep config order.
    routes.sort_by_key(|route| std::cmp::Reverse(route.priority));

    let runtime_service = RuntimeService {
        name: service.name.clone(),
        host: Arc::from(service.host.as_str()),
        protocols: service.protocols.clone(),
        routes,
        config: Arc::clone(service),
    };

    // Publish every (deduplicated) host referenced by the compiled matchers.
    let hosts: std::collections::HashSet<_> = runtime_service
        .routes
        .iter()
        .flat_map(|route| route.matcher.get_hosts())
        .collect();
    register_hosts(hosts);

    Ok(runtime_service)
}
impl ServiceManager {
pub fn new(config: Arc<JokowayConfig>) -> Result<Self, JokowayError> {
let services = Self::compile_services(&config);
Ok(Self {
services: ArcSwap::from_pointee(services),
callbacks: Mutex::new(Vec::new()),
})
}
fn compile_services(config: &JokowayConfig) -> Vec<RuntimeService> {
let mut services = Vec::with_capacity(config.services.len());
for svc_config in &config.services {
match compile_service(svc_config) {
Ok(svc) => services.push(svc),
Err(e) => {
log::error!("Failed to compile service {}: {}", svc_config.name, e);
}
}
}
services
}
pub fn get_indices_for_protocols(&self, allowed_protocols: &[ServiceProtocol]) -> Vec<usize> {
let services = self.services.load();
services
.iter()
.enumerate()
.filter_map(|(idx, svc)| {
if svc.protocols.is_empty() {
Some(idx)
} else if svc.protocols.iter().any(|p| allowed_protocols.contains(p)) {
Some(idx)
} else {
None
}
})
.collect()
}
pub fn get_all(&self) -> arc_swap::Guard<Arc<Vec<RuntimeService>>> {
self.services.load()
}
pub fn add_services_changed_callback<F>(&self, callback: F)
where
F: Fn() + Send + Sync + 'static,
{
self.callbacks.lock().unwrap().push(Box::new(callback));
}
fn notify_callbacks(&self) {
let callbacks = self.callbacks.lock().unwrap();
for callback in callbacks.iter() {
callback();
}
}
pub fn list_services(&self) -> Vec<RuntimeService> {
let services = self.services.load();
services
.iter()
.map(|svc| RuntimeService {
name: svc.name.clone(),
host: svc.host.clone(),
protocols: svc.protocols.clone(),
routes: Vec::new(), config: svc.config.clone(),
})
.collect()
}
pub fn verify_service(&self, name: &str) -> bool {
let services = self.services.load();
services.iter().any(|svc| svc.name == name)
}
pub fn add_service(&self, service: crate::config::models::Service) -> Result<(), JokowayError> {
let service = Arc::new(service);
if self.verify_service(&service.name) {
return Err(JokowayError::Config(format!(
"Service {} already exists",
service.name
)));
}
let runtime_service = compile_service(&service)?;
self.services.rcu(|old| {
let mut next = Vec::with_capacity(old.len() + 1);
for svc in old.iter() {
next.push(RuntimeService {
name: svc.name.clone(),
host: svc.host.clone(),
protocols: svc.protocols.clone(),
routes: svc
.routes
.iter()
.map(|r| RuntimeRoute {
matcher: r.matcher.clone_box(),
priority: r.priority,
max_retries: r.max_retries,
req_transformer: r.req_transformer.clone(),
res_transformer: r.res_transformer.clone(),
})
.collect(),
config: svc.config.clone(),
});
}
next.push(RuntimeService {
name: runtime_service.name.clone(),
host: runtime_service.host.clone(),
protocols: runtime_service.protocols.clone(),
routes: runtime_service
.routes
.iter()
.map(|r| RuntimeRoute {
matcher: r.matcher.clone_box(),
priority: r.priority,
max_retries: r.max_retries,
req_transformer: r.req_transformer.clone(),
res_transformer: r.res_transformer.clone(),
})
.collect(),
config: runtime_service.config.clone(),
});
next
});
self.notify_callbacks();
log::info!("Added service: {}", service.name);
Ok(())
}
pub fn update_service(
&self,
name: &str,
service: crate::config::models::Service,
) -> Result<(), JokowayError> {
let service = Arc::new(service);
if !self.verify_service(name) {
return Err(JokowayError::Config(format!(
"Service {} does not exist",
name
)));
}
let runtime_service = compile_service(&service)?;
self.services.rcu(|old| {
let mut next = Vec::with_capacity(old.len());
for svc in old.iter() {
if svc.name == name {
next.push(RuntimeService {
name: runtime_service.name.clone(),
host: runtime_service.host.clone(),
protocols: runtime_service.protocols.clone(),
routes: runtime_service
.routes
.iter()
.map(|r| RuntimeRoute {
matcher: r.matcher.clone_box(),
priority: r.priority,
max_retries: r.max_retries,
req_transformer: r.req_transformer.clone(),
res_transformer: r.res_transformer.clone(),
})
.collect(),
config: runtime_service.config.clone(),
});
} else {
next.push(RuntimeService {
name: svc.name.clone(),
host: svc.host.clone(),
protocols: svc.protocols.clone(),
routes: svc
.routes
.iter()
.map(|r| RuntimeRoute {
matcher: r.matcher.clone_box(),
priority: r.priority,
max_retries: r.max_retries,
req_transformer: r.req_transformer.clone(),
res_transformer: r.res_transformer.clone(),
})
.collect(),
config: svc.config.clone(),
});
}
}
next
});
self.notify_callbacks();
log::info!("Updated service: {}", name);
Ok(())
}
pub fn remove_service(&self, name: &str) -> Result<(), JokowayError> {
if !self.verify_service(name) {
log::warn!("Service {} does not exist, skipping remove", name);
return Ok(());
}
self.services.rcu(|old| {
let mut next = Vec::with_capacity(old.len());
for svc in old.iter() {
if svc.name != name {
next.push(RuntimeService {
name: svc.name.clone(),
host: svc.host.clone(),
protocols: svc.protocols.clone(),
routes: svc
.routes
.iter()
.map(|r| RuntimeRoute {
matcher: r.matcher.clone_box(),
priority: r.priority,
max_retries: r.max_retries,
req_transformer: r.req_transformer.clone(),
res_transformer: r.res_transformer.clone(),
})
.collect(),
config: svc.config.clone(),
});
}
}
next
});
self.notify_callbacks();
log::info!("Removed service: {}", name);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::models::{JokowayConfig, Route, Service};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Checks that protocol filtering includes protocol-less services and excludes
    /// non-matching ones.
    #[test]
    fn test_protocol_filtering() {
        let config = JokowayConfig {
            services: vec![
                Service {
                    name: "http_only".to_string(),
                    host: "http_backend".to_string(),
                    protocols: vec![ServiceProtocol::Http],
                    routes: vec![],
                    ..Default::default()
                },
                Service {
                    name: "https_only".to_string(),
                    host: "https_backend".to_string(),
                    protocols: vec![ServiceProtocol::Https],
                    routes: vec![],
                    ..Default::default()
                },
                Service {
                    name: "dual_protocol".to_string(),
                    host: "dual_backend".to_string(),
                    protocols: vec![ServiceProtocol::Http, ServiceProtocol::Https],
                    routes: vec![],
                    ..Default::default()
                },
                Service {
                    name: "no_protocol".to_string(),
                    host: "default_backend".to_string(),
                    protocols: vec![],
                    routes: vec![],
                    ..Default::default()
                },
            ]
            .into_iter()
            .map(Arc::new)
            .collect(),
            ..Default::default()
        };
        let manager =
            ServiceManager::new(Arc::new(config)).expect("Failed to create ServiceManager");
        let all_services = manager.get_all();
        // Borrow the indices instead of consuming them, so callers need no clones.
        let get_names = |indices: &[usize]| -> Vec<String> {
            indices
                .iter()
                .map(|&i| all_services[i].name.clone())
                .collect()
        };

        let http_indices = manager.get_indices_for_protocols(&HTTP_PROTOCOLS);
        let http_names = get_names(&http_indices);
        assert_eq!(http_indices.len(), 3);
        assert!(http_names.contains(&"http_only".to_string()));
        assert!(http_names.contains(&"dual_protocol".to_string()));
        assert!(http_names.contains(&"no_protocol".to_string()));
        assert!(!http_names.contains(&"https_only".to_string()));

        let https_indices = manager.get_indices_for_protocols(&HTTPS_PROTOCOLS);
        let https_names = get_names(&https_indices);
        assert_eq!(https_indices.len(), 3);
        assert!(https_names.contains(&"https_only".to_string()));
        assert!(https_names.contains(&"dual_protocol".to_string()));
        assert!(https_names.contains(&"no_protocol".to_string()));
        assert!(!https_names.contains(&"http_only".to_string()));

        let all_indices = manager.get_indices_for_protocols(&ALL_PROTOCOLS);
        assert_eq!(all_indices.len(), 4);
    }

    /// Checks that routes are compiled and ordered by descending priority.
    #[test]
    fn test_service_compilation() {
        let config = JokowayConfig {
            services: vec![Service {
                name: "test_service".to_string(),
                host: "test_backend".to_string(),
                protocols: vec![ServiceProtocol::Http],
                routes: vec![
                    Route {
                        name: "test_route_1".to_string(),
                        rule: "Host(`example.com`)".to_string(),
                        priority: Some(10),
                        ..Default::default()
                    },
                    Route {
                        name: "test_route_2".to_string(),
                        rule: "Host(`api.example.com`)".to_string(),
                        priority: Some(5),
                        ..Default::default()
                    },
                ],
                ..Default::default()
            }]
            .into_iter()
            .map(Arc::new)
            .collect(),
            ..Default::default()
        };
        let manager =
            ServiceManager::new(Arc::new(config)).expect("Failed to create ServiceManager");
        let services = manager.get_all();
        assert_eq!(services.len(), 1);
        assert_eq!(services[0].name, "test_service");
        assert_eq!(services[0].routes.len(), 2);
        assert_eq!(services[0].routes[0].priority, 10);
        assert_eq!(services[0].routes[1].priority, 5);
        // Neither the route nor the service sets max_retries, so the default of 1 applies.
        assert_eq!(services[0].routes[0].max_retries, 1);
    }

    /// Checks retry-budget resolution: the route's value wins, else the service's
    /// value is used.
    #[test]
    fn test_max_retries_resolution() {
        let config = JokowayConfig {
            services: vec![Service {
                name: "retry_service".to_string(),
                host: "retry_backend".to_string(),
                max_retries: Some(3),
                routes: vec![
                    Route {
                        name: "inherits".to_string(),
                        rule: "Host(`inherit.example.com`)".to_string(),
                        priority: Some(2),
                        ..Default::default()
                    },
                    Route {
                        name: "overrides".to_string(),
                        rule: "Host(`override.example.com`)".to_string(),
                        priority: Some(1),
                        max_retries: Some(7),
                        ..Default::default()
                    },
                ],
                ..Default::default()
            }]
            .into_iter()
            .map(Arc::new)
            .collect(),
            ..Default::default()
        };
        let manager =
            ServiceManager::new(Arc::new(config)).expect("Failed to create ServiceManager");
        let services = manager.get_all();
        assert_eq!(services[0].routes[0].max_retries, 3);
        assert_eq!(services[0].routes[1].max_retries, 7);
    }

    /// Exercises the runtime mutation API and checks that only successful
    /// mutations notify subscribers.
    #[test]
    fn test_add_update_remove_service() {
        let manager = ServiceManager::new(Arc::new(JokowayConfig::default()))
            .expect("Failed to create ServiceManager");
        let notifications = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&notifications);
        manager.add_services_changed_callback(move || {
            counter.fetch_add(1, Ordering::SeqCst);
        });

        let make = |name: &str, host: &str| Service {
            name: name.to_string(),
            host: host.to_string(),
            ..Default::default()
        };

        manager
            .add_service(make("svc", "backend_a"))
            .expect("first add succeeds");
        assert!(manager.verify_service("svc"));
        assert!(manager.add_service(make("svc", "backend_b")).is_err());

        assert!(manager.update_service("missing", make("missing", "x")).is_err());
        manager
            .update_service("svc", make("svc", "backend_b"))
            .expect("update succeeds");
        assert_eq!(&*manager.get_all()[0].host, "backend_b");

        manager.remove_service("svc").expect("remove succeeds");
        assert!(!manager.verify_service("svc"));
        // Removing an unknown service is a logged no-op, not an error.
        assert!(manager.remove_service("svc").is_ok());

        // Only the add, update and first remove actually changed anything.
        assert_eq!(notifications.load(Ordering::SeqCst), 3);
    }
}