use std::sync::Arc;
use std::time::Duration;
use epics_pva_rs::client::PvaClient;
use epics_pva_rs::server_native::{CompositeSource, PvaServer, PvaServerConfig};
use super::channel_cache::{ChannelCache, DEFAULT_CLEANUP_INTERVAL};
use super::control::ControlSource;
use super::error::{GwError, GwResult};
use super::source::GatewayChannelSource;
/// One upstream registered on the builder via `add_upstream`.
struct UpstreamTenant {
/// Unique name that downstreams use to refer to this upstream; `start` rejects duplicates.
label: String,
/// Shared client handed to the `ChannelCache` that `start` creates for this upstream.
client: Arc<PvaClient>,
}
/// Description of one downstream server, recorded by `add_downstream` and consumed by `start`.
struct DownstreamSpec {
/// Unique name of this downstream; `start` rejects duplicates.
label: String,
/// Server configuration passed to `PvaServer::start` for this downstream.
config: PvaServerConfig,
/// Labels of the upstreams this downstream exposes. Must be non-empty, and each entry's
/// position in the list is used as the `order` argument when its source is registered.
upstream_labels: Vec<String>,
/// When `Some` and non-empty, a `ControlSource` is registered under this prefix.
control_prefix: Option<String>,
}
/// Builder that collects labelled upstream clients and downstream server specs, then
/// starts one `PvaServer` per downstream in `start`.
pub struct MultiTenantPvaGatewayBuilder {
/// Upstreams registered so far, in registration order.
upstreams: Vec<UpstreamTenant>,
/// Downstreams registered so far, in registration order.
downstreams: Vec<DownstreamSpec>,
/// Passed to `ChannelCache::with_max_entries` for every upstream cache.
cleanup_interval: Duration,
/// Copied onto every `GatewayChannelSource` created by `start`.
connect_timeout: Duration,
/// Passed to `ChannelCache::with_max_entries` for every upstream cache.
max_cache_entries: usize,
/// Copied onto every `GatewayChannelSource` created by `start`.
max_subscribers: usize,
}
impl Default for MultiTenantPvaGatewayBuilder {
fn default() -> Self {
Self::new()
}
}
impl MultiTenantPvaGatewayBuilder {
pub fn new() -> Self {
Self {
upstreams: Vec::new(),
downstreams: Vec::new(),
cleanup_interval: DEFAULT_CLEANUP_INTERVAL,
connect_timeout: Duration::from_secs(5),
max_cache_entries: super::channel_cache::DEFAULT_MAX_ENTRIES,
max_subscribers: 100_000,
}
}
pub fn add_upstream(mut self, label: impl Into<String>, client: Arc<PvaClient>) -> Self {
self.upstreams.push(UpstreamTenant {
label: label.into(),
client,
});
self
}
pub fn add_downstream(
mut self,
label: impl Into<String>,
config: PvaServerConfig,
upstream_labels: &[&str],
control_prefix: Option<String>,
) -> Self {
self.downstreams.push(DownstreamSpec {
label: label.into(),
config,
upstream_labels: upstream_labels.iter().map(|s| (*s).to_string()).collect(),
control_prefix,
});
self
}
pub fn cleanup_interval(mut self, d: Duration) -> Self {
self.cleanup_interval = d;
self
}
pub fn connect_timeout(mut self, d: Duration) -> Self {
self.connect_timeout = d;
self
}
pub fn max_cache_entries(mut self, n: usize) -> Self {
self.max_cache_entries = n;
self
}
pub fn max_subscribers(mut self, n: usize) -> Self {
self.max_subscribers = n;
self
}
pub fn start(self) -> GwResult<MultiTenantPvaGateway> {
if self.upstreams.is_empty() {
return Err(GwError::Other(
"MultiTenantPvaGatewayBuilder: at least one upstream required".into(),
));
}
if self.downstreams.is_empty() {
return Err(GwError::Other(
"MultiTenantPvaGatewayBuilder: at least one downstream required \
(a gateway with no listeners would resolve no clients)"
.into(),
));
}
for (i, a) in self.upstreams.iter().enumerate() {
for b in &self.upstreams[i + 1..] {
if a.label == b.label {
return Err(GwError::Other(format!(
"duplicate upstream label '{}'",
a.label
)));
}
}
}
for (i, a) in self.downstreams.iter().enumerate() {
for b in &self.downstreams[i + 1..] {
if a.label == b.label {
return Err(GwError::Other(format!(
"duplicate downstream label '{}'",
a.label
)));
}
}
if a.upstream_labels.is_empty() {
return Err(GwError::Other(format!(
"downstream '{}' must reference at least one upstream",
a.label
)));
}
}
let mut caches: Vec<(String, Arc<ChannelCache>)> = Vec::with_capacity(self.upstreams.len());
for u in &self.upstreams {
let c = ChannelCache::with_max_entries(
u.client.clone(),
self.cleanup_interval,
self.max_cache_entries,
);
caches.push((u.label.clone(), c));
}
let mut servers: Vec<DownstreamHandle> = Vec::with_capacity(self.downstreams.len());
for ds in self.downstreams {
let mut sources: Vec<(String, Arc<ChannelCache>)> = Vec::new();
for needed in &ds.upstream_labels {
let cache = caches
.iter()
.find(|(lbl, _)| lbl == needed)
.map(|(_, c)| c.clone())
.ok_or_else(|| {
GwError::Other(format!(
"downstream '{}' references unknown upstream label '{needed}'",
ds.label
))
})?;
sources.push((needed.clone(), cache));
}
let composite = CompositeSource::new();
let mut first_gw_source: Option<GatewayChannelSource> = None;
let mut first_cache: Option<Arc<ChannelCache>> = None;
for (i, (label, cache)) in sources.iter().enumerate() {
let mut src = GatewayChannelSource::new(cache.clone());
src.connect_timeout = self.connect_timeout;
src.max_subscribers = self.max_subscribers;
if first_gw_source.is_none() {
first_gw_source = Some(src.clone());
first_cache = Some(cache.clone());
}
let order = i as i32; let name = format!("gateway:{label}");
composite
.add_source(&name, Arc::new(src), order)
.map_err(|e| {
GwError::Other(format!(
"downstream '{}' source '{name}' registration: {e}",
ds.label
))
})?;
}
if let (Some(prefix), Some(gw_src), Some(cache)) =
(ds.control_prefix.as_ref(), first_gw_source, first_cache)
{
if !prefix.is_empty() {
let control = ControlSource::new(prefix, cache, gw_src);
composite
.add_source("__gw_control", Arc::new(control), -100)
.map_err(|e| {
GwError::Other(format!(
"downstream '{}' control source registration: {e}",
ds.label
))
})?;
}
}
let server = PvaServer::start(composite, ds.config)?;
servers.push(DownstreamHandle {
label: ds.label,
server,
});
}
Ok(MultiTenantPvaGateway { caches, servers })
}
}
/// A started downstream server together with the label it was registered under.
struct DownstreamHandle {
/// Label given to `add_downstream`; used for lookup by `MultiTenantPvaGateway::downstream`.
label: String,
/// Server returned by `PvaServer::start` for this downstream.
server: PvaServer,
}
/// Running gateway produced by `MultiTenantPvaGatewayBuilder::start`.
pub struct MultiTenantPvaGateway {
/// One `(label, cache)` pair per registered upstream, in registration order.
caches: Vec<(String, Arc<ChannelCache>)>,
/// One handle per started downstream server, in registration order.
servers: Vec<DownstreamHandle>,
}
impl MultiTenantPvaGateway {
    /// Number of downstream servers held by this gateway.
    pub fn downstream_count(&self) -> usize {
        self.servers.len()
    }

    /// Number of upstream caches held by this gateway.
    pub fn upstream_count(&self) -> usize {
        self.caches.len()
    }

    /// Looks up a downstream server by the label it was registered under.
    ///
    /// Returns `None` when no downstream carries that label.
    pub fn downstream(&self, label: &str) -> Option<&PvaServer> {
        for handle in &self.servers {
            if handle.label == label {
                return Some(&handle.server);
            }
        }
        None
    }

    /// Looks up an upstream's channel cache by the upstream label.
    ///
    /// Returns `None` when no upstream carries that label.
    pub fn upstream_cache(&self, label: &str) -> Option<&Arc<ChannelCache>> {
        for (name, cache) in &self.caches {
            if name == label {
                return Some(cache);
            }
        }
        None
    }

    /// Calls `stop` on every downstream server, in registration order.
    pub fn stop_all(&self) {
        self.servers.iter().for_each(|handle| {
            handle.server.stop();
        });
    }
}