use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use parking_lot::Mutex;
use tokio::sync::RwLock;
use tokio::sync::mpsc;
use epics_base_rs::server::access_security::{AccessLevel, AccessSecurityConfig};
use epics_pva_rs::client::PvaClient;
use epics_pva_rs::pvdata::{FieldDesc, PvField};
use epics_pva_rs::server::native_source::AcfCell;
use epics_pva_rs::server_native::source::{AccessChecked, ChannelContext, ChannelSource};
use super::channel_cache::ChannelCache;
/// A monitor update kept in its still-encoded wire form.
///
/// `subscribe_raw` relays `body` and `byte_order` straight into a
/// `RawMonitorEvent` without decoding them (presumably the broadcast item type
/// behind `entry.subscribe_raw()` — confirm in `channel_cache`).
#[derive(Clone, Debug)]
pub struct RawEvent {
    /// Encoded payload bytes, forwarded downstream untouched.
    pub body: bytes::Bytes,
    /// Byte order in which `body` is encoded.
    pub byte_order: epics_pva_rs::proto::ByteOrder,
}
/// Maps a PV name to the name of the access-security group (ASG) whose rules
/// apply to it.
pub type AsgResolver = Arc<dyn Fn(&str) -> String + Send + Sync>;

/// Resolver used when none has been configured: every PV is placed in the
/// `DEFAULT` ASG.
fn default_asg_resolver() -> AsgResolver {
    let everything_is_default = |_pv_name: &str| -> String { String::from("DEFAULT") };
    Arc::new(everything_is_default)
}
/// [`ChannelSource`] that answers downstream PVA clients by proxying PVs from
/// upstream through a shared [`ChannelCache`], applying the gateway's own
/// access-security configuration.
///
/// Clones share the cache, the live-subscriber counter, the upstream client
/// pool, the ACF cell and the ASG resolver, because all of these are held
/// behind `Arc`. The public tuning fields are copied per clone.
#[derive(Clone)]
pub struct GatewayChannelSource {
    /// Shared upstream channel cache; also supplies the default upstream client.
    cache: Arc<ChannelCache>,
    /// Timeout passed to every `ChannelCache::lookup`.
    pub connect_timeout: Duration,
    /// Capacity of each per-subscription forwarding queue.
    pub subscriber_queue: usize,
    /// Upper bound on a single forwarded RPC.
    pub rpc_timeout: Duration,
    /// Maximum number of simultaneously live downstream subscriptions.
    pub max_subscribers: usize,
    /// Live downstream subscriptions (incremented on subscribe, released when
    /// the forwarding task ends).
    subscriber_count: Arc<AtomicUsize>,
    /// Upstream clients created to forward specific downstream identities.
    /// The key is chosen by `upstream_client_for`.
    upstream_pool: Arc<Mutex<HashMap<(String, String), Arc<PvaClient>>>>,
    /// Active gateway ACF; `None` means no ACF is loaded.
    acf: AcfCell,
    /// Current PV-name → ASG mapping; swappable at runtime.
    asg_resolver: Arc<RwLock<AsgResolver>>,
    /// Gate returned from `ChannelSource::access`, built over `acf` and
    /// `asg_resolver`.
    gate: epics_base_rs::server::access_security::AccessGate,
}
impl GatewayChannelSource {
    /// Creates a gateway source over `cache`.
    ///
    /// Defaults: 5 s connect timeout, 64-slot subscriber queues, 30 s RPC
    /// timeout, at most 100 000 live subscriptions, no ACF loaded, and every
    /// PV resolved to ASG `DEFAULT`.
    pub fn new(cache: Arc<ChannelCache>) -> Self {
        let acf: AcfCell = Arc::new(RwLock::new(None));
        let asg_resolver = Arc::new(RwLock::new(default_asg_resolver()));
        // The gate captures clones of both shared cells, so later `set_acf` /
        // `set_asg_resolver` calls are seen by it without rebuilding the gate.
        let gate = Self::build_gate(acf.clone(), asg_resolver.clone());
        Self {
            cache,
            connect_timeout: Duration::from_secs(5),
            subscriber_queue: 64,
            rpc_timeout: Duration::from_secs(30),
            max_subscribers: 100_000,
            subscriber_count: Arc::new(AtomicUsize::new(0)),
            upstream_pool: Arc::new(Mutex::new(HashMap::new())),
            acf,
            asg_resolver,
            gate,
        }
    }

    /// Builds the access gate over the shared ACF cell and ASG resolver.
    ///
    /// The resolver handed to the gate reads the *current* `asg_resolver` on
    /// every call and always pairs the ASG with level `0` (presumably the
    /// access-security level — confirm against `AsgAslResolver`).
    fn build_gate(
        acf: AcfCell,
        asg_resolver: Arc<RwLock<AsgResolver>>,
    ) -> epics_base_rs::server::access_security::AccessGate {
        use epics_base_rs::server::access_security::{AccessGate, AsgAslResolver};
        let resolver: AsgAslResolver = Arc::new(move |pv_name| {
            let asg_resolver = asg_resolver.clone();
            Box::pin(async move {
                let g = asg_resolver.read().await;
                let asg = (g)(&pv_name);
                (asg, 0u8)
            })
        });
        AccessGate::required(acf, resolver)
    }

    /// Replaces the PV → ASG mapping. `None` restores the all-`DEFAULT`
    /// resolver.
    ///
    /// Bumps the gate's ACL version afterwards, presumably so that access
    /// decisions cached under the old mapping are re-evaluated.
    pub async fn set_asg_resolver(&self, resolver: Option<AsgResolver>) {
        *self.asg_resolver.write().await = resolver.unwrap_or_else(default_asg_resolver);
        self.gate.bump_acl_version();
    }

    /// Replaces the gateway ACF (`None` removes it) and bumps the gate's ACL
    /// version, presumably so that cached access decisions are re-evaluated.
    pub async fn set_acf(&self, cfg: Option<AccessSecurityConfig>) {
        *self.acf.write().await = cfg;
        self.gate.bump_acl_version();
    }

    /// Test helper: evaluates the gateway ACF for `pv` as seen by `ctx`.
    ///
    /// Returns `ReadWrite` when no ACF is loaded. Otherwise checks the PV's
    /// resolved ASG with level 0 and an empty final argument (its meaning is
    /// defined by `check_access_method` — confirm there).
    #[cfg(test)]
    async fn acl_level(&self, pv: &str, ctx: &ChannelContext) -> AccessLevel {
        let guard = self.acf.read().await;
        match *guard {
            None => AccessLevel::ReadWrite,
            Some(ref cfg) => {
                let resolver = self.asg_resolver.read().await;
                let asg = (resolver)(pv);
                cfg.check_access_method(&asg, &ctx.host, &ctx.account, 0, &ctx.method, "")
            }
        }
    }

    /// Returns the upstream client to use when forwarding an operation on
    /// behalf of the downstream peer described by `ctx`.
    ///
    /// Anonymous peers (empty account, or method `anonymous`) share the
    /// cache's default client. Every other peer gets a pooled client built
    /// with its account and host, so the upstream server sees the real
    /// requester.
    ///
    /// The pool key is `(account, host)`, which is exactly the set of inputs
    /// the client is built from. The host must be part of the key: a client
    /// identifies itself with the host it was built for, so sharing one across
    /// hosts would make upstream host-based rules (HAGs) evaluate the wrong
    /// host. The method is not part of the key because the client does not
    /// depend on it.
    ///
    /// NOTE(review): entries are never evicted, so every distinct identity
    /// keeps its client alive for the lifetime of this source.
    fn upstream_client_for(&self, ctx: &ChannelContext) -> Arc<PvaClient> {
        if ctx.account.is_empty() || ctx.method == "anonymous" {
            return self.cache.client().clone();
        }
        let key = (ctx.account.clone(), ctx.host.clone());
        // parking_lot mutex, held only for the lookup/insert; nothing is awaited here.
        let mut pool = self.upstream_pool.lock();
        pool.entry(key)
            .or_insert_with(|| {
                Arc::new(
                    PvaClient::builder()
                        .user(ctx.account.clone())
                        .host(ctx.host.clone())
                        .build(),
                )
            })
            .clone()
    }

    /// The shared upstream channel cache.
    pub fn cache(&self) -> &Arc<ChannelCache> {
        &self.cache
    }

    /// Number of entries currently held in the channel cache.
    pub async fn cached_entry_count(&self) -> usize {
        self.cache.entry_count().await
    }

    /// Snapshot (relaxed load) of the number of live downstream subscriptions
    /// counted against `max_subscribers`.
    pub fn live_subscribers(&self) -> usize {
        self.subscriber_count.load(Ordering::Relaxed)
    }
}
impl ChannelSource for GatewayChannelSource {
fn access(&self) -> &epics_base_rs::server::access_security::AccessGate {
&self.gate
}
async fn list_pvs(&self) -> Vec<String> {
self.cache.names().await
}
async fn has_pv(&self, name: &str) -> bool {
self.cache.lookup(name, self.connect_timeout).await.is_ok()
}
async fn get_introspection(&self, name: &str) -> Option<FieldDesc> {
let entry = self.cache.lookup(name, self.connect_timeout).await.ok()?;
entry.introspection()
}
async fn get_value(&self, name: &str) -> Option<PvField> {
let entry = self.cache.lookup(name, self.connect_timeout).await.ok()?;
entry.snapshot()
}
async fn put_value(&self, name: &str, value: PvField) -> Result<(), String> {
let _entry = self
.cache
.lookup(name, self.connect_timeout)
.await
.map_err(|e| e.to_string())?;
let value_str = pvfield_to_pvput_string(&value)
.ok_or_else(|| "unsupported PvField shape for upstream PUT".to_string())?;
self.cache
.client()
.pvput(name, &value_str)
.await
.map_err(|e| e.to_string())
}
async fn put_value_checked(
&self,
checked: AccessChecked,
value: PvField,
ctx: ChannelContext,
) -> Result<(), String> {
if !checked.allows_write() {
tracing::debug!(
pv = %checked.pv_name(),
account = %ctx.account,
method = %ctx.method,
"pva-gateway: PUT denied by gateway ACF"
);
return Err(format!(
"PUT denied by gateway access security: \
PV '{pv}' from {host}/{account}/{method}",
pv = checked.pv_name(),
host = ctx.host,
account = ctx.account,
method = ctx.method,
));
}
let name = checked.pv_name();
let _entry = self
.cache
.lookup(name, self.connect_timeout)
.await
.map_err(|e| e.to_string())?;
let value_str = pvfield_to_pvput_string(&value)
.ok_or_else(|| "unsupported PvField shape for upstream PUT".to_string())?;
let client = self.upstream_client_for(&ctx);
tracing::debug!(
pv = %name,
account = %ctx.account,
method = %ctx.method,
"pva-gateway: forwarding PUT with downstream credentials"
);
client
.pvput(name, &value_str)
.await
.map_err(|e| e.to_string())
}
async fn is_writable(&self, name: &str) -> bool {
self.cache.peek(name).await.is_some()
}
async fn rpc(
&self,
name: &str,
request_desc: FieldDesc,
request_value: PvField,
) -> Result<(FieldDesc, PvField), String> {
let _entry = self
.cache
.lookup(name, self.connect_timeout)
.await
.map_err(|e| e.to_string())?;
let result = tokio::time::timeout(
self.rpc_timeout,
self.cache
.client()
.pvrpc(name, &request_desc, &request_value),
)
.await;
match result {
Ok(Ok(pair)) => Ok(pair),
Ok(Err(e)) => Err(e.to_string()),
Err(_) => Err(format!("upstream rpc timeout for {name}")),
}
}
async fn subscribe_raw(
&self,
name: &str,
) -> Option<mpsc::Receiver<epics_pva_rs::server_native::RawMonitorEvent>> {
if let Some(v) = epics_base_rs::runtime::env::get("EPICS_PVA_GW_RAW_FRAMES") {
if v.eq_ignore_ascii_case("NO") || v.eq_ignore_ascii_case("FALSE") || v == "0" {
return None;
}
}
let prev = self.subscriber_count.fetch_add(1, Ordering::Relaxed);
if prev >= self.max_subscribers {
self.subscriber_count.fetch_sub(1, Ordering::Relaxed);
tracing::warn!(
pv = %name,
live = prev,
cap = self.max_subscribers,
"pva-gateway: raw subscriber cap reached, refusing"
);
return None;
}
let entry = match self.cache.lookup(name, self.connect_timeout).await {
Ok(e) => e,
Err(_) => {
self.subscriber_count.fetch_sub(1, Ordering::Relaxed);
return None;
}
};
let mut bcast = entry.subscribe_raw();
let (mpsc_tx, mpsc_rx) =
mpsc::channel::<epics_pva_rs::server_native::RawMonitorEvent>(self.subscriber_queue);
let counter = self.subscriber_count.clone();
tokio::spawn(async move {
struct CounterGuard(std::sync::Arc<std::sync::atomic::AtomicUsize>);
impl Drop for CounterGuard {
fn drop(&mut self) {
self.0.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
}
}
let _guard = CounterGuard(counter);
loop {
match bcast.recv().await {
Ok(ev) => {
let out = epics_pva_rs::server_native::RawMonitorEvent {
body_bytes: ev.body,
byte_order: ev.byte_order,
};
if mpsc_tx.send(out).await.is_err() {
return;
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
continue;
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => return,
}
}
});
Some(mpsc_rx)
}
async fn subscribe(&self, name: &str) -> Option<mpsc::Receiver<PvField>> {
let prev = self.subscriber_count.fetch_add(1, Ordering::Relaxed);
if prev >= self.max_subscribers {
self.subscriber_count.fetch_sub(1, Ordering::Relaxed);
tracing::warn!(
pv = %name,
live = prev,
cap = self.max_subscribers,
"pva-gateway: subscriber cap reached, refusing"
);
return None;
}
let entry = match self.cache.lookup(name, self.connect_timeout).await {
Ok(e) => e,
Err(_) => {
self.subscriber_count.fetch_sub(1, Ordering::Relaxed);
return None;
}
};
let mut bcast_rx = entry.subscribe();
let initial = entry.snapshot();
let (mpsc_tx, mpsc_rx) = mpsc::channel(self.subscriber_queue);
let counter = self.subscriber_count.clone();
tokio::spawn(async move {
struct CounterGuard(Arc<AtomicUsize>);
impl Drop for CounterGuard {
fn drop(&mut self) {
self.0.fetch_sub(1, Ordering::Relaxed);
}
}
let _guard = CounterGuard(counter);
if let Some(v) = initial {
if mpsc_tx.send(v).await.is_err() {
return;
}
}
loop {
match bcast_rx.recv().await {
Ok(v) => {
if mpsc_tx.send(v).await.is_err() {
return;
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
continue;
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => return,
}
}
});
Some(mpsc_rx)
}
fn notify_watermark_high(&self, name: &str) {
tracing::warn!(
pv = %name,
"pva-gateway: downstream monitor outbox crossed high watermark"
);
let cache = self.cache.clone();
let name_owned = name.to_string();
tokio::spawn(async move {
if let Some(entry) = cache.peek(&name_owned).await {
if let Some(p) = entry.pauser_snapshot() {
p.pause().await;
}
}
});
}
fn notify_watermark_low(&self, name: &str) {
tracing::debug!(
pv = %name,
"pva-gateway: downstream monitor outbox drained below low watermark"
);
let cache = self.cache.clone();
let name_owned = name.to_string();
tokio::spawn(async move {
if let Some(entry) = cache.peek(&name_owned).await {
if let Some(p) = entry.pauser_snapshot() {
p.resume().await;
}
}
});
}
}
/// Renders a downstream PUT value as the text passed to `PvaClient::pvput`.
///
/// Conversion rules:
/// - scalars go through [`scalar_to_string`];
/// - scalar arrays become their elements separated by single spaces;
/// - structures are rendered through their first `value` field;
/// - variants and unions with a selected member are unwrapped recursively.
///
/// Returns `None` for any other shape, for a structure with no `value` field,
/// or for a union whose selector is negative.
///
/// NOTE(review): string-array elements that contain spaces are joined
/// unquoted — confirm this matches what `pvput` expects.
fn pvfield_to_pvput_string(v: &PvField) -> Option<String> {
    match v {
        PvField::Scalar(sv) => Some(scalar_to_string(sv)),
        PvField::ScalarArray(items) => {
            let rendered = items.iter().map(scalar_to_string).collect::<Vec<String>>();
            Some(rendered.join(" "))
        }
        PvField::Structure(s) => {
            for (field_name, member) in &s.fields {
                if field_name == "value" {
                    return pvfield_to_pvput_string(member);
                }
            }
            None
        }
        PvField::Variant(boxed) => pvfield_to_pvput_string(&boxed.value),
        // A negative selector (no member selected) falls through to `None`.
        PvField::Union {
            selector, value, ..
        } if *selector >= 0 => pvfield_to_pvput_string(value),
        _ => None,
    }
}
fn scalar_to_string(sv: &epics_pva_rs::pvdata::ScalarValue) -> String {
use epics_pva_rs::pvdata::ScalarValue::*;
match sv {
Boolean(b) => {
if *b {
"1".into()
} else {
"0".into()
}
}
Byte(x) => x.to_string(),
UByte(x) => x.to_string(),
Short(x) => x.to_string(),
UShort(x) => x.to_string(),
Int(x) => x.to_string(),
UInt(x) => x.to_string(),
Long(x) => x.to_string(),
ULong(x) => x.to_string(),
Float(x) => x.to_string(),
Double(x) => x.to_string(),
String(s) => s.clone(),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use epics_base_rs::server::access_security::parse_acf;

    /// ACF in which only UAG `ops` (alice) may READ and nobody may WRITE.
    /// Every other account therefore has no access at all.
    const OPS_READ_ONLY_ACF: &str = r#"
UAG(ops) { alice }
ASG(DEFAULT) {
RULE(1, READ) { UAG(ops) }
}
"#;

    /// Loopback `ChannelContext` carrying the given host, account and method.
    fn make_ctx(host: &str, account: &str, method: &str) -> ChannelContext {
        use std::net::{IpAddr, Ipv4Addr, SocketAddr};
        let loopback = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        ChannelContext {
            peer: SocketAddr::new(loopback, 0),
            account: account.to_owned(),
            method: method.to_owned(),
            host: host.to_owned(),
        }
    }

    /// Source over a fresh upstream client. The cache gets a 60 s duration
    /// argument (presumably its idle timeout — confirm in `ChannelCache::new`).
    fn make_source() -> GatewayChannelSource {
        let upstream = Arc::new(PvaClient::builder().build());
        GatewayChannelSource::new(ChannelCache::new(upstream, Duration::from_secs(60)))
    }

    /// Same as `make_source`, with `acf_text` parsed and installed as the
    /// gateway ACF.
    async fn make_source_with_acf(acf_text: &str) -> GatewayChannelSource {
        let src = make_source();
        let cfg = parse_acf(acf_text).unwrap();
        src.set_acf(Some(cfg)).await;
        src
    }

    #[tokio::test]
    async fn acl_level_no_acf_is_readwrite() {
        let ctx = make_ctx("h", "anyone", "anonymous");
        let level = make_source().acl_level("any:pv", &ctx).await;
        assert!(matches!(level, AccessLevel::ReadWrite));
    }

    #[tokio::test]
    async fn put_value_ctx_denied_when_acf_no_write() {
        let src = make_source_with_acf(
            r#"
UAG(admins) { admin }
ASG(DEFAULT) {
RULE(1, READ)
RULE(1, WRITE) { UAG(admins) }
}
"#,
        )
        .await;
        let value = PvField::Scalar(epics_pva_rs::pvdata::ScalarValue::Double(0.0));
        let intruder = make_ctx("h", "intruder", "anonymous");
        let err = src
            .put_value_ctx("any:pv", value, intruder)
            .await
            .expect_err("PUT must be denied for non-admin under DEFAULT ASG");
        assert!(
            err.contains("denied by gateway access security"),
            "denial reason must name the gateway as enforcement point: {err:?}",
        );
    }

    #[tokio::test]
    async fn get_and_subscribe_denied_when_acf_no_access() {
        let src = make_source_with_acf(OPS_READ_ONLY_ACF).await;
        let intruder = make_ctx("h", "intruder", "anonymous");
        let got = src.get_value_ctx("any:pv", intruder.clone()).await;
        assert!(got.is_none(), "GET must be denied for non-ops");
        let monitor = src.subscribe_ctx("any:pv", intruder).await;
        assert!(monitor.is_none(), "MONITOR must be denied for non-ops");
    }

    #[tokio::test]
    async fn acf_swap_takes_effect_on_next_op() {
        let src = make_source_with_acf(r#"ASG(DEFAULT) { RULE(1, READ) }"#).await;
        let value = PvField::Scalar(epics_pva_rs::pvdata::ScalarValue::Double(0.0));
        let ctx = make_ctx("h", "anyone", "anonymous");

        let before = src
            .put_value_ctx("any:pv", value.clone(), ctx.clone())
            .await;
        assert!(before.is_err(), "initial deny-WRITE policy must reject PUT");

        let permissive = parse_acf(r#"ASG(DEFAULT) { RULE(1, READ) RULE(1, WRITE) }"#).unwrap();
        src.set_acf(Some(permissive)).await;
        // The PUT may still fail upstream (presumably no server is reachable
        // in tests), but it must not fail because of the gateway ACF.
        if let Err(msg) = src.put_value_ctx("any:pv", value, ctx).await {
            assert!(
                !msg.contains("denied by gateway access security"),
                "post-swap PUT must NOT be ACL-denied: {msg:?}",
            );
        }
    }

    #[tokio::test]
    async fn acl_level_uses_per_pv_asg_resolver() {
        let src = make_source_with_acf(
            r#"
UAG(admins) { admin }
ASG(DEFAULT) {
RULE(1, READ)
RULE(1, WRITE)
}
ASG(OPERATOR) {
RULE(1, READ)
RULE(1, WRITE) { UAG(admins) }
}
ASG(LOCKED) {
RULE(1, READ)
}
"#,
        )
        .await;
        src.set_asg_resolver(Some(Arc::new(|pv: &str| {
            let asg = if pv.starts_with("set:") {
                "OPERATOR"
            } else if pv.starts_with("dev:") {
                "LOCKED"
            } else {
                "DEFAULT"
            };
            asg.to_string()
        })))
        .await;

        let guest = make_ctx("anyhost", "guest", "anonymous");
        let admin = make_ctx("anyhost", "admin", "anonymous");
        let expectations = [
            ("other:val", &guest, AccessLevel::ReadWrite),
            ("set:current", &guest, AccessLevel::Read),
            ("set:current", &admin, AccessLevel::ReadWrite),
            ("dev:hwid", &admin, AccessLevel::Read),
            ("dev:hwid", &guest, AccessLevel::Read),
        ];
        for (pv, ctx, expected) in expectations {
            assert_eq!(
                src.acl_level(pv, ctx).await,
                expected,
                "pv={pv} account={}",
                ctx.account
            );
        }
    }

    #[tokio::test]
    async fn asg_resolver_swap_takes_effect_on_next_acl_check() {
        let src = make_source_with_acf(
            r#"
ASG(DEFAULT) {
RULE(1, READ)
RULE(1, WRITE)
}
ASG(LOCKED) {
RULE(1, READ)
}
"#,
        )
        .await;
        let ctx = make_ctx("h", "anyone", "anonymous");

        assert_eq!(src.acl_level("X", &ctx).await, AccessLevel::ReadWrite);

        src.set_asg_resolver(Some(Arc::new(|_pv| "LOCKED".to_string())))
            .await;
        assert_eq!(src.acl_level("X", &ctx).await, AccessLevel::Read);

        src.set_asg_resolver(None).await;
        assert_eq!(src.acl_level("X", &ctx).await, AccessLevel::ReadWrite);
    }

    #[tokio::test]
    async fn rpc_ctx_denies_when_acf_no_access() {
        let src = make_source_with_acf(OPS_READ_ONLY_ACF).await;
        let intruder = make_ctx("anyhost", "intruder", "anonymous");
        let err = src
            .rpc_ctx("some:rpc", FieldDesc::Variant, PvField::Null, intruder)
            .await
            .expect_err("RPC must be denied for NoAccess peer");
        assert!(
            err.contains("denied by gateway access security"),
            "denial message must name the gateway as enforcer: {err:?}",
        );
    }

    #[tokio::test]
    async fn subscribe_raw_ctx_denies_when_acf_no_access() {
        let src = make_source_with_acf(OPS_READ_ONLY_ACF).await;
        let intruder = make_ctx("anyhost", "intruder", "anonymous");
        let rx = src.subscribe_raw_ctx("any:pv", intruder).await;
        assert!(
            rx.is_none(),
            "raw subscribe must be denied for a NoAccess peer"
        );
    }
}