use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use arc_swap::ArcSwap;
use epics_base_rs::error::CaError;
use epics_base_rs::server::database::PvDatabase;
use epics_base_rs::server::pv::{WriteContext, WriteHook};
use epics_base_rs::types::EpicsValue;
use epics_ca_rs::client::{CaChannel, CaClient};
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use crate::error::{BridgeError, BridgeResult};
use super::access::AccessConfig;
use super::cache::{PvCache, PvState};
use super::putlog::{PutLog, PutOutcome};
use super::pvlist::PvList;
use super::stats::Stats;
/// Bookkeeping for one live upstream subscription held by `UpstreamManager`.
struct UpstreamSubscription {
    /// Channel to the upstream PV. Shared with the monitor task and the shadow
    /// PV's write hook, and reused by `put`/`get` while the subscription exists.
    channel: Arc<CaChannel>,
    /// Monitor/reconnect task; aborted by `unsubscribe` and `shutdown`.
    task: JoinHandle<()>,
    /// Access-security group name given to `ensure_subscribed`; reported by
    /// `asg_for`. (The write hook receives its own copy and falls back to
    /// "DEFAULT" when this is `None`.)
    asg: Option<String>,
    /// Access-security level given to `ensure_subscribed`; reported by `asg_for`.
    asl: i32,
}
/// Shared state cloned into every shadow-PV write hook and consulted on each write.
#[derive(Clone)]
struct WriteHookEnv {
    /// When true, every write is rejected before any other check runs.
    read_only: bool,
    /// Access rules; a snapshot is taken with `load_full` on every write.
    access: Arc<ArcSwap<AccessConfig>>,
    /// PV list used for host-based denials; snapshotted with `load_full` per write.
    pvlist: Arc<ArcSwap<PvList>>,
    /// Optional put log; receives denied, successful and failed writes.
    putlog: Option<Arc<PutLog>>,
    /// Gateway counters (rejected writes, forwarded writes, monitor events).
    stats: Arc<Stats>,
    /// Requested by the monitor task when an upstream PV leaves `Disconnect`.
    beacon_anomaly: Arc<super::beacon::BeaconAnomaly>,
}
/// Construction parameters for `UpstreamManager::new`.
pub struct UpstreamManagerConfig {
    /// Gateway PV cache whose entries track upstream connection state.
    pub cache: Arc<RwLock<PvCache>>,
    /// Database in which one shadow PV per subscribed upstream name is registered.
    pub shadow_db: Arc<PvDatabase>,
    /// Access rules consulted by shadow-PV write hooks.
    pub access: Arc<ArcSwap<AccessConfig>>,
    /// PV list consulted by write hooks for host-based denials.
    pub pvlist: Arc<ArcSwap<PvList>>,
    /// Optional put log for write auditing.
    pub putlog: Option<Arc<PutLog>>,
    /// Gateway statistics sink.
    pub stats: Arc<Stats>,
    /// Reject every client write when true.
    pub read_only: bool,
    /// Beacon-anomaly trigger fired on upstream reconnects.
    pub beacon_anomaly: Arc<super::beacon::BeaconAnomaly>,
}
/// Owns the gateway's upstream CA subscriptions and their shadow PVs.
pub struct UpstreamManager {
    /// CA client used to create every upstream channel.
    client: Arc<CaClient>,
    /// Gateway PV cache; monitor tasks update entry state and values here.
    cache: Arc<RwLock<PvCache>>,
    /// Database hosting one shadow PV per subscribed upstream name.
    shadow_db: Arc<PvDatabase>,
    /// Environment cloned into each shadow PV's write hook.
    write_env: WriteHookEnv,
    /// Active subscriptions keyed by upstream PV name.
    subs: parking_lot::Mutex<HashMap<String, UpstreamSubscription>>,
    /// In-flight `ensure_subscribed` calls keyed by upstream name; concurrent
    /// callers for the same name wait on the `Notify` instead of subscribing twice.
    pending: parking_lot::Mutex<HashMap<String, Arc<tokio::sync::Notify>>>,
}
impl UpstreamManager {
pub async fn new(cfg: UpstreamManagerConfig) -> BridgeResult<Self> {
let client = CaClient::new()
.await
.map_err(|e| BridgeError::PutRejected(format!("CaClient init: {e}")))?;
Ok(Self {
client: Arc::new(client),
cache: cfg.cache,
shadow_db: cfg.shadow_db,
write_env: WriteHookEnv {
read_only: cfg.read_only,
access: cfg.access,
pvlist: cfg.pvlist,
putlog: cfg.putlog,
stats: cfg.stats,
beacon_anomaly: cfg.beacon_anomaly,
},
subs: parking_lot::Mutex::new(HashMap::new()),
pending: parking_lot::Mutex::new(HashMap::new()),
})
}
pub fn subscription_count(&self) -> usize {
self.subs.lock().len()
}
pub fn is_subscribed(&self, name: &str) -> bool {
self.subs.lock().contains_key(name)
}
pub async fn ensure_subscribed(
&self,
upstream_name: &str,
asg: Option<String>,
asl: i32,
) -> BridgeResult<()> {
if self.subs.lock().contains_key(upstream_name) {
return Ok(());
}
enum Decision {
WaitFor(Arc<tokio::sync::Notify>),
Owner(Arc<tokio::sync::Notify>),
}
let decision = {
let mut pending = self.pending.lock();
if let Some(existing) = pending.get(upstream_name) {
Decision::WaitFor(existing.clone())
} else {
let n = Arc::new(tokio::sync::Notify::new());
pending.insert(upstream_name.to_string(), n.clone());
Decision::Owner(n)
}
};
let dedup_notify = match decision {
Decision::WaitFor(n) => {
n.notified().await;
let already = self.subs.lock().contains_key(upstream_name);
return if already {
Ok(())
} else {
Err(BridgeError::PutRejected(format!(
"upstream subscribe failed (peer creator): {upstream_name}"
)))
};
}
Decision::Owner(n) => n,
};
struct PendingGuard<'a> {
owner: &'a UpstreamManager,
key: &'a str,
notify: Arc<tokio::sync::Notify>,
}
impl Drop for PendingGuard<'_> {
fn drop(&mut self) {
self.owner.pending.lock().remove(self.key);
self.notify.notify_waiters();
}
}
let _guard = PendingGuard {
owner: self,
key: upstream_name,
notify: dedup_notify.clone(),
};
{
let mut cache = self.cache.write().await;
let entry = cache.get_or_create(upstream_name);
entry.write().await.set_state(PvState::Connecting);
}
let channel = Arc::new(self.client.create_channel(upstream_name));
let initial_value = match tokio::time::timeout(Duration::from_millis(500), channel.get())
.await
{
Ok(Ok((_dbf, v))) => v,
Ok(Err(e)) => {
tracing::info!(
pv = upstream_name,
error = %e,
"ca-gateway-rs: DBR negotiation get failed; using Double(0.0) placeholder"
);
EpicsValue::Double(0.0)
}
Err(_) => {
tracing::info!(
pv = upstream_name,
"ca-gateway-rs: DBR negotiation get timed out; using Double(0.0) placeholder"
);
EpicsValue::Double(0.0)
}
};
let hook = build_write_hook(
upstream_name.to_string(),
channel.clone(),
asg.clone(),
asl,
self.write_env.clone(),
);
self.shadow_db.remove_simple_pv(upstream_name).await;
if let Err(e) = self
.shadow_db
.add_pv_with_hook(upstream_name, initial_value, hook)
.await
{
return Err(BridgeError::PutRejected(format!(
"shadow PV register failed: {e}"
)));
}
let mut monitor = match channel.subscribe().await {
Ok(m) => m,
Err(e) => {
self.shadow_db.remove_simple_pv(upstream_name).await;
return Err(BridgeError::PutRejected(format!("subscribe failed: {e}")));
}
};
let cache_clone = self.cache.clone();
let db_clone = self.shadow_db.clone();
let channel_for_task = channel.clone();
let stats_for_task = self.write_env.stats.clone();
let beacon_anomaly_for_task = self.write_env.beacon_anomaly.clone();
let name = upstream_name.to_string();
let task = tokio::spawn(async move {
let mut backoff = Duration::from_millis(250);
let max_backoff = Duration::from_secs(30);
loop {
while let Some(result) = monitor.recv().await {
let snapshot = match result {
Ok(s) => s,
Err(_) => continue,
};
stats_for_task.record_event();
let mut transitioned_from_disconnect = false;
if let Some(entry_arc) = cache_clone.read().await.get(&name) {
let mut entry = entry_arc.write().await;
let was_disconnect = matches!(entry.state, PvState::Disconnect);
if matches!(entry.state, PvState::Connecting | PvState::Disconnect) {
let next = if entry.subscriber_count() > 0 {
PvState::Active
} else {
PvState::Inactive
};
entry.set_state(next);
}
entry.update(snapshot.clone());
transitioned_from_disconnect = was_disconnect;
}
if transitioned_from_disconnect {
beacon_anomaly_for_task.request();
}
let _ = db_clone
.put_pv_and_post(&name, snapshot.value.clone())
.await;
backoff = Duration::from_millis(250);
}
if let Some(entry_arc) = cache_clone.read().await.get(&name) {
entry_arc.write().await.set_state(PvState::Disconnect);
}
let _ = db_clone.post_alarm(&name, 3, 0).await;
tokio::time::sleep(backoff).await;
backoff = std::cmp::min(backoff * 2, max_backoff);
if cache_clone.read().await.get(&name).is_none() {
return;
}
match channel_for_task.subscribe().await {
Ok(new_monitor) => {
monitor = new_monitor;
}
Err(_) => {
continue;
}
}
}
});
self.subs.lock().insert(
upstream_name.to_string(),
UpstreamSubscription {
channel,
task,
asg,
asl,
},
);
Ok(())
}
pub async fn unsubscribe(&self, upstream_name: &str) {
let removed = self.subs.lock().remove(upstream_name);
if let Some(sub) = removed {
sub.task.abort();
}
let _ = self.shadow_db.remove_simple_pv(upstream_name).await;
}
pub async fn put(&self, upstream_name: &str, value: &EpicsValue) -> BridgeResult<()> {
let channel_for_op = self
.subs
.lock()
.get(upstream_name)
.map(|s| s.channel.clone());
if let Some(ch) = channel_for_op {
return ch
.put(value)
.await
.map_err(|e| BridgeError::PutRejected(format!("upstream put: {e}")));
}
let channel = self.client.create_channel(upstream_name);
channel
.put(value)
.await
.map_err(|e| BridgeError::PutRejected(format!("upstream put: {e}")))
}
pub async fn get(&self, upstream_name: &str) -> BridgeResult<EpicsValue> {
let channel_for_op = self
.subs
.lock()
.get(upstream_name)
.map(|s| s.channel.clone());
if let Some(ch) = channel_for_op {
let (_dbf, value) = ch
.get()
.await
.map_err(|e| BridgeError::PutRejected(format!("upstream get: {e}")))?;
return Ok(value);
}
let channel = self.client.create_channel(upstream_name);
let (_dbf, value) = channel
.get()
.await
.map_err(|e| BridgeError::PutRejected(format!("upstream get: {e}")))?;
Ok(value)
}
pub fn asg_for(&self, upstream_name: &str) -> Option<(Option<String>, i32)> {
self.subs
.lock()
.get(upstream_name)
.map(|s| (s.asg.clone(), s.asl))
}
pub async fn sweep_orphaned(&self) {
let cache = self.cache.read().await;
let orphans: Vec<String> = self
.subs
.lock()
.keys()
.filter(|name| cache.get(name).is_none())
.cloned()
.collect();
drop(cache);
for name in orphans {
self.unsubscribe(&name).await;
}
}
pub async fn shutdown(&self) {
let drained: Vec<UpstreamSubscription> =
self.subs.lock().drain().map(|(_, sub)| sub).collect();
for sub in drained {
sub.task.abort();
}
self.client.shutdown().await;
}
}
fn build_write_hook(
pv_name: String,
channel: Arc<CaChannel>,
asg: Option<String>,
asl: i32,
env: WriteHookEnv,
) -> WriteHook {
Arc::new(move |new_value: EpicsValue, ctx: WriteContext| {
let pv_name = pv_name.clone();
let channel = channel.clone();
let asg = asg.clone();
let env = env.clone();
Box::pin(async move {
let value_str = format_value_for_audit(&new_value, 256);
if env.read_only {
env.stats.record_readonly_reject();
log_denial(&env, &ctx, &pv_name, &value_str).await;
return Err(CaError::ReadOnlyField(format!(
"{pv_name} (gateway in read-only mode)"
)));
}
let pvlist = env.pvlist.load_full();
if pvlist.is_host_denied(&pv_name, &ctx.host) {
env.stats.record_readonly_reject();
log_denial(&env, &ctx, &pv_name, &value_str).await;
return Err(CaError::PutDisabled(format!(
"{pv_name} (host {} denied by pvlist)",
ctx.host
)));
}
let access = env.access.load_full();
if ctx.user.is_empty() && access.has_rules() {
env.stats.record_readonly_reject();
log_denial(&env, &ctx, &pv_name, &value_str).await;
return Err(CaError::ReadOnlyField(format!(
"{pv_name} (no client identity)"
)));
}
let asg_ref = asg.as_deref().unwrap_or("DEFAULT");
if !access.can_write(asg_ref, asl, &ctx.user, &ctx.host) {
env.stats.record_readonly_reject();
log_denial(&env, &ctx, &pv_name, &value_str).await;
return Err(CaError::ReadOnlyField(format!(
"{pv_name} (asg {asg_ref}, user {})",
ctx.user
)));
}
let result = channel.put(&new_value).await;
if let Some(pl) = &env.putlog {
let outcome = if result.is_ok() {
PutOutcome::Ok
} else {
PutOutcome::Failed
};
if let Err(e) = pl
.log(&ctx.user, &ctx.host, &pv_name, &value_str, outcome)
.await
{
tracing::warn!(
target: "ca_gateway::putlog",
error = %e,
"ca-gateway-rs: putlog write failed"
);
}
}
if result.is_ok() {
env.stats.record_put();
}
result
})
})
}
/// Records a refused write as `PutOutcome::Denied` in the put log, if configured.
///
/// A failure to write the log entry is reported through `tracing::warn!` and
/// otherwise ignored, so logging never changes the outcome of the write.
async fn log_denial(env: &WriteHookEnv, ctx: &WriteContext, pv: &str, value: &str) {
    let Some(putlog) = env.putlog.as_ref() else {
        return;
    };
    let logged = putlog
        .log(&ctx.user, &ctx.host, pv, value, PutOutcome::Denied)
        .await;
    if let Err(e) = logged {
        tracing::warn!(
            target: "ca_gateway::putlog",
            error = %e,
            "ca-gateway-rs: putlog write failed"
        );
    }
}
/// Renders `v` for the put log, using at most `max_len` bytes.
///
/// Numeric, enum and string arrays longer than 32 elements are cut to their
/// first 32 elements, and character arrays to `max_len` elements, before
/// formatting, so huge waveforms are never rendered in full. Whenever anything
/// was dropped (elements before formatting, or bytes after it), the result ends
/// in `...`. An audit record therefore never presents a partial value as if it
/// were the whole write. The ellipsis is omitted only when `max_len` is too
/// small to hold it. Truncation always lands on a UTF-8 character boundary.
fn format_value_for_audit(v: &EpicsValue, max_len: usize) -> String {
    const HEAD_PEEK_ELEMS: usize = 32;
    const ELLIPSIS: &str = "...";

    // Head of an oversized array; `None` means `v` is rendered as-is.
    let head: Option<EpicsValue> = match v {
        EpicsValue::ShortArray(arr) if arr.len() > HEAD_PEEK_ELEMS => {
            Some(EpicsValue::ShortArray(arr[..HEAD_PEEK_ELEMS].to_vec()))
        }
        EpicsValue::FloatArray(arr) if arr.len() > HEAD_PEEK_ELEMS => {
            Some(EpicsValue::FloatArray(arr[..HEAD_PEEK_ELEMS].to_vec()))
        }
        EpicsValue::EnumArray(arr) if arr.len() > HEAD_PEEK_ELEMS => {
            Some(EpicsValue::EnumArray(arr[..HEAD_PEEK_ELEMS].to_vec()))
        }
        EpicsValue::DoubleArray(arr) if arr.len() > HEAD_PEEK_ELEMS => {
            Some(EpicsValue::DoubleArray(arr[..HEAD_PEEK_ELEMS].to_vec()))
        }
        EpicsValue::LongArray(arr) if arr.len() > HEAD_PEEK_ELEMS => {
            Some(EpicsValue::LongArray(arr[..HEAD_PEEK_ELEMS].to_vec()))
        }
        EpicsValue::CharArray(arr) if arr.len() > max_len => {
            Some(EpicsValue::CharArray(arr[..max_len].to_vec()))
        }
        EpicsValue::StringArray(arr) if arr.len() > HEAD_PEEK_ELEMS => {
            Some(EpicsValue::StringArray(arr[..HEAD_PEEK_ELEMS].to_vec()))
        }
        _ => None,
    };
    let elements_elided = head.is_some();
    let rendered = head.as_ref().unwrap_or(v).to_string();

    if !elements_elided && rendered.len() <= max_len {
        return rendered;
    }

    // Something was dropped: keep as much of the text as fits alongside the marker.
    let marker = if max_len >= ELLIPSIS.len() { ELLIPSIS } else { "" };
    let mut end = (max_len - marker.len()).min(rendered.len());
    while !rendered.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}{marker}", &rendered[..end])
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Permissive hook environment: writable, allow-all access, empty pvlist, no put log.
    fn dummy_env() -> WriteHookEnv {
        let access = Arc::new(ArcSwap::from_pointee(AccessConfig::allow_all()));
        let pvlist = Arc::new(ArcSwap::from_pointee(PvList::new()));
        let stats = Arc::new(Stats::new("gw:".into()));
        let beacon_anomaly = Arc::new(crate::ca_gateway::beacon::BeaconAnomaly::new());
        WriteHookEnv {
            read_only: false,
            access,
            pvlist,
            putlog: None,
            stats,
            beacon_anomaly,
        }
    }

    /// A freshly constructed manager starts with no subscriptions.
    #[tokio::test]
    async fn manager_construct() {
        let cache = Arc::new(RwLock::new(PvCache::new()));
        let shadow_db = Arc::new(PvDatabase::new());
        let env = dummy_env();
        let config = UpstreamManagerConfig {
            cache,
            shadow_db,
            access: Arc::clone(&env.access),
            pvlist: Arc::clone(&env.pvlist),
            putlog: None,
            stats: Arc::clone(&env.stats),
            read_only: false,
            beacon_anomaly: Arc::clone(&env.beacon_anomaly),
        };
        let mgr = UpstreamManager::new(config)
            .await
            .expect("UpstreamManager::new should succeed");
        assert_eq!(mgr.subscription_count(), 0);
        assert!(!mgr.is_subscribed("ANY"));
    }

    /// Smoke test: the cache module's `GwPvEntry::new_connecting` is reachable from here.
    #[test]
    fn _entry_imports() {
        let _ = super::super::cache::GwPvEntry::new_connecting("X");
    }
}