use std::sync::Arc;
use epics_pva_rs::pvdata::{FieldDesc, PvField, PvStructure, ScalarType, ScalarValue};
use epics_pva_rs::server_native::source::{AccessChecked, ChannelContext, ChannelSource};
use tokio::sync::mpsc;
use super::channel_cache::ChannelCache;
use super::source::GatewayChannelSource;
/// Predicate deciding whether the caller described by a [`ChannelContext`] is an
/// authorised gateway operator. `rpc_checked` consults it before running any
/// control RPC; returning `true` allows the RPC, `false` denies it.
pub type CredentialCheck = Arc<dyn Fn(&ChannelContext) -> bool + Send + Sync>;
/// Channel source exposing gateway diagnostics (`{prefix}:cacheSize`,
/// `:upstreamCount`, `:liveSubscribers`, `:report`) and operator control RPCs
/// (`{prefix}:flush`, `:drop`, `:reload`).
///
/// `Clone` lets `subscribe` move a copy of the source into its polling task.
#[derive(Clone)]
pub struct ControlSource {
/// Name prefix shared by every PV this source serves; names are `{prefix}:<suffix>`.
prefix: String,
/// Shared channel cache: its entry count is reported, and it is flushed or
/// pruned by the `:flush` / `:drop` RPCs.
cache: Arc<ChannelCache>,
/// Gateway source: supplies the live-subscriber count and receives the new
/// ACF policy on `:reload`.
gateway_source: GatewayChannelSource,
/// Operator authorisation predicate; deny-all unless replaced via
/// `with_credential_check`.
credential_check: CredentialCheck,
/// Default ACF file used by `:reload` when the RPC supplies no `path` argument.
acf_path: Option<String>,
}
impl ControlSource {
    /// Build a control source whose PVs are all named `{prefix}:<suffix>`.
    ///
    /// The credential predicate starts out deny-all, so every control RPC is
    /// refused until [`ControlSource::with_credential_check`] installs a real
    /// check. No default ACF path is configured.
    pub fn new(
        prefix: impl Into<String>,
        cache: Arc<ChannelCache>,
        gateway_source: GatewayChannelSource,
    ) -> Self {
        Self {
            prefix: prefix.into(),
            cache,
            gateway_source,
            // Fail closed: nobody is an operator until the embedder says so.
            credential_check: Arc::new(|_ctx| false),
            acf_path: None,
        }
    }

    /// Install the predicate that decides whether a caller may run control RPCs.
    pub fn with_credential_check(mut self, check: CredentialCheck) -> Self {
        self.credential_check = check;
        self
    }

    /// Set the ACF file that `:reload` falls back to when the RPC has no `path`.
    pub fn with_acf_path(mut self, path: impl Into<String>) -> Self {
        self.acf_path = Some(path.into());
        self
    }

    /// Names of the read-only diagnostic PVs.
    fn diag_pv_names(&self) -> [String; 4] {
        [
            format!("{}:cacheSize", self.prefix),
            format!("{}:upstreamCount", self.prefix),
            format!("{}:liveSubscribers", self.prefix),
            format!("{}:report", self.prefix),
        ]
    }

    /// Names of the operator control PVs, in the order flush, drop, reload.
    fn control_pv_names(&self) -> [String; 3] {
        [
            format!("{}:flush", self.prefix),
            format!("{}:drop", self.prefix),
            format!("{}:reload", self.prefix),
        ]
    }

    /// NTScalar value carrying a single `long` field named `value`.
    fn nt_scalar_long(v: i64) -> PvField {
        let mut s = PvStructure::new("epics:nt/NTScalar:1.0");
        s.fields
            .push(("value".into(), PvField::Scalar(ScalarValue::Long(v))));
        PvField::Structure(s)
    }

    /// Introspection matching [`ControlSource::nt_scalar_long`].
    fn nt_scalar_long_desc() -> FieldDesc {
        FieldDesc::Structure {
            struct_id: "epics:nt/NTScalar:1.0".into(),
            fields: vec![("value".into(), FieldDesc::Scalar(ScalarType::Long))],
        }
    }

    /// NTScalar value carrying a single `string` field named `value`.
    fn nt_scalar_string(v: String) -> PvField {
        let mut s = PvStructure::new("epics:nt/NTScalar:1.0");
        s.fields
            .push(("value".into(), PvField::Scalar(ScalarValue::String(v))));
        PvField::Structure(s)
    }

    /// Introspection matching [`ControlSource::nt_scalar_string`].
    fn nt_scalar_string_desc() -> FieldDesc {
        FieldDesc::Structure {
            struct_id: "epics:nt/NTScalar:1.0".into(),
            fields: vec![("value".into(), FieldDesc::Scalar(ScalarType::String))],
        }
    }

    /// Introspection of every control RPC reply: a `long` value and a `string` message.
    fn control_reply_desc() -> FieldDesc {
        FieldDesc::Structure {
            struct_id: "epics:nt/NTScalar:1.0".into(),
            fields: vec![
                ("value".into(), FieldDesc::Scalar(ScalarType::Long)),
                ("message".into(), FieldDesc::Scalar(ScalarType::String)),
            ],
        }
    }

    /// Build a control RPC reply matching [`ControlSource::control_reply_desc`].
    fn control_reply(value: i64, message: impl Into<String>) -> PvField {
        let mut s = PvStructure::new("epics:nt/NTScalar:1.0");
        s.fields
            .push(("value".into(), PvField::Scalar(ScalarValue::Long(value))));
        s.fields.push((
            "message".into(),
            PvField::Scalar(ScalarValue::String(message.into())),
        ));
        PvField::Structure(s)
    }

    /// True when `name` is exactly one of the diagnostic PV names.
    fn is_diag(&self, name: &str) -> bool {
        self.diag_pv_names().iter().any(|n| n == name)
    }

    /// True when `name` is exactly one of the control PV names.
    fn is_control(&self, name: &str) -> bool {
        self.control_pv_names().iter().any(|n| n == name)
    }

    /// True when this source serves `name` at all.
    fn matches(&self, name: &str) -> bool {
        self.is_diag(name) || self.is_control(name)
    }

    /// Extract the string argument `arg` from an RPC request value.
    ///
    /// Looks first in an NTURI-style `query` sub-structure, then falls back to
    /// a top-level field of the same name. Fields that are not string scalars
    /// are ignored. Returns `None` if no string is found.
    fn rpc_string_arg(request_value: &PvField, arg: &str) -> Option<String> {
        fn scalar_string(f: &PvField) -> Option<String> {
            match f {
                PvField::Scalar(ScalarValue::String(s)) => Some(s.clone()),
                _ => None,
            }
        }
        let PvField::Structure(root) = request_value else {
            return None;
        };
        if let Some((_, PvField::Structure(query))) = root.fields.iter().find(|(n, _)| n == "query")
        {
            if let Some((_, f)) = query.fields.iter().find(|(n, _)| n == arg) {
                if let Some(s) = scalar_string(f) {
                    return Some(s);
                }
            }
        }
        root.fields
            .iter()
            .find(|(n, _)| n == arg)
            .and_then(|(_, f)| scalar_string(f))
    }

    /// Execute a control RPC whose caller has **already** been authorised.
    ///
    /// No credential check happens here. Callers must gate on
    /// `credential_check` first, as `rpc_checked` does.
    ///
    /// # Errors
    /// Returns a human-readable message when a required argument is missing or
    /// empty, when the ACF file cannot be read or parsed, or when `name` is not
    /// one of this source's control PVs.
    async fn run_control_rpc(
        &self,
        name: &str,
        request_value: &PvField,
    ) -> Result<(FieldDesc, PvField), String> {
        let [flush_pv, drop_pv, reload_pv] = self.control_pv_names();
        // Match exact names so an unexpected name is rejected rather than
        // falling through to a privileged action such as an ACF reload.
        if name == flush_pv {
            self.flush_rpc(name).await
        } else if name == drop_pv {
            self.drop_rpc(name, request_value).await
        } else if name == reload_pv {
            self.reload_rpc(name, request_value).await
        } else {
            Err(format!("unknown control PV '{name}'"))
        }
    }

    /// `:flush`: empty the channel cache and report how many entries were removed.
    async fn flush_rpc(&self, name: &str) -> Result<(FieldDesc, PvField), String> {
        // Saturate instead of wrapping if the count ever exceeded i64::MAX.
        let removed = i64::try_from(self.cache.flush().await).unwrap_or(i64::MAX);
        tracing::info!(
            gateway_control = %name,
            removed,
            "pva-gateway: operator flushed channel cache via RPC"
        );
        Ok((
            Self::control_reply_desc(),
            Self::control_reply(removed, format!("flushed {removed} cache entries")),
        ))
    }

    /// `:drop`: remove the single cache entry named by the string argument `pv`.
    ///
    /// The reply value is 1 if an entry was removed and 0 if none was present.
    async fn drop_rpc(
        &self,
        name: &str,
        request_value: &PvField,
    ) -> Result<(FieldDesc, PvField), String> {
        let target = Self::rpc_string_arg(request_value, "pv").ok_or_else(|| {
            "drop RPC requires a string 'pv' argument naming the cache entry".to_string()
        })?;
        if target.is_empty() {
            return Err("drop RPC 'pv' argument must not be empty".to_string());
        }
        let dropped = self.cache.drop_entry(&target).await;
        tracing::info!(
            gateway_control = %name,
            pv = %target,
            dropped,
            "pva-gateway: operator dropped cache entry via RPC"
        );
        let msg = if dropped {
            format!("dropped cache entry '{target}'")
        } else {
            format!("cache entry '{target}' was not present")
        };
        Ok((
            Self::control_reply_desc(),
            Self::control_reply(i64::from(dropped), msg),
        ))
    }

    /// `:reload`: read and parse an ACF file, then install it on the gateway source.
    ///
    /// Uses the non-empty `path` argument if given, otherwise the configured
    /// default ACF path. The current policy is only replaced if the file parses.
    async fn reload_rpc(
        &self,
        name: &str,
        request_value: &PvField,
    ) -> Result<(FieldDesc, PvField), String> {
        let path = Self::rpc_string_arg(request_value, "path")
            .filter(|p| !p.is_empty())
            .or_else(|| self.acf_path.clone())
            .ok_or_else(|| {
                "reload RPC requires a 'path' argument (no default ACF path \
configured on this gateway)"
                    .to_string()
            })?;
        let content = tokio::fs::read_to_string(&path)
            .await
            .map_err(|e| format!("reload: cannot read ACF file '{path}': {e}"))?;
        let cfg = epics_base_rs::server::access_security::parse_acf(&content)
            .map_err(|e| format!("reload: cannot parse ACF file '{path}': {e}"))?;
        self.gateway_source.set_acf(Some(cfg)).await;
        tracing::info!(
            gateway_control = %name,
            acf_path = %path,
            "pva-gateway: operator reloaded ACF policy via RPC"
        );
        Ok((
            Self::control_reply_desc(),
            Self::control_reply(0, format!("reloaded ACF policy from '{path}'")),
        ))
    }
}
impl ChannelSource for ControlSource {
    /// All diagnostic PV names followed by all control PV names.
    async fn list_pvs(&self) -> Vec<String> {
        let mut out = self.diag_pv_names().to_vec();
        out.extend(self.control_pv_names());
        out
    }

    /// True for any diagnostic or control PV served by this source.
    async fn has_pv(&self, name: &str) -> bool {
        self.matches(name)
    }

    /// Control PVs share the reply layout. `:report` is a string NTScalar and
    /// every other diagnostic PV is a long NTScalar. Unknown names yield `None`.
    async fn get_introspection(&self, name: &str) -> Option<FieldDesc> {
        if self.is_control(name) {
            return Some(Self::control_reply_desc());
        }
        if !self.is_diag(name) {
            return None;
        }
        if name.ends_with(":report") {
            Some(Self::nt_scalar_string_desc())
        } else {
            Some(Self::nt_scalar_long_desc())
        }
    }

    /// Current value of a diagnostic PV, or a placeholder reply for control PVs.
    ///
    /// `:upstreamCount` reports the same cache entry count as `:cacheSize`.
    /// NOTE(review): this presumably assumes one upstream channel per cache
    /// entry; confirm.
    async fn get_value(&self, name: &str) -> Option<PvField> {
        if self.is_control(name) {
            return Some(Self::control_reply(
                0,
                "control PV — invoke via RPC (pvcall), not GET/PUT",
            ));
        }
        if !self.is_diag(name) {
            return None;
        }
        let cache_size = self.cache.entry_count().await as i64;
        let live_subs = self.gateway_source.live_subscribers() as i64;
        if name.ends_with(":cacheSize") || name.ends_with(":upstreamCount") {
            Some(Self::nt_scalar_long(cache_size))
        } else if name.ends_with(":liveSubscribers") {
            Some(Self::nt_scalar_long(live_subs))
        } else if name.ends_with(":report") {
            let report = format!(
                "cacheSize={cache_size} upstreamCount={cache_size} liveSubscribers={live_subs}"
            );
            Some(Self::nt_scalar_string(report))
        } else {
            None
        }
    }

    /// Nothing served here accepts PUT.
    async fn is_writable(&self, _name: &str) -> bool {
        false
    }

    /// Always fails; control PVs are driven by RPC, everything else is read-only.
    async fn put_value(&self, name: &str, _value: PvField) -> Result<(), String> {
        if self.is_control(name) {
            Err(format!(
                "control PV '{name}' is invoked via RPC (pvcall), not PUT"
            ))
        } else {
            Err("control PVs are read-only".to_string())
        }
    }

    /// Always fails, with a reason specific to control, diagnostic or unknown names.
    async fn process(&self, name: &str) -> Result<(), String> {
        if self.is_control(name) {
            Err(format!(
                "control PV '{name}' is invoked via RPC (pvcall), not PROCESS"
            ))
        } else if self.is_diag(name) {
            Err(format!(
                "'{name}' is a read-only diagnostic PV — PROCESS not supported"
            ))
        } else {
            Err(format!("unknown control PV '{name}'"))
        }
    }

    /// Context-free RPC entry point. It always refuses, because control RPCs
    /// need the caller's credentials, which only `rpc_checked` receives.
    async fn rpc(
        &self,
        name: &str,
        _request_desc: FieldDesc,
        _request_value: PvField,
    ) -> Result<(FieldDesc, PvField), String> {
        if self.is_control(name) {
            Err(format!(
                "control RPC '{name}' requires an authenticated request"
            ))
        } else if self.is_diag(name) {
            Err(format!("'{name}' is a read-only diagnostic PV, not an RPC"))
        } else {
            Err(format!("unknown control PV '{name}'"))
        }
    }

    /// Credential-gated RPC entry point.
    ///
    /// Non-control names are delegated to [`ChannelSource::rpc`]. For control
    /// PVs the configured credential predicate must accept `ctx`, otherwise the
    /// request is logged and denied; accepted requests run the control action.
    async fn rpc_checked(
        &self,
        checked: AccessChecked,
        request_desc: FieldDesc,
        request_value: PvField,
        ctx: ChannelContext,
    ) -> Result<(FieldDesc, PvField), String> {
        let name = checked.pv_name().to_string();
        if !self.is_control(&name) {
            return self.rpc(&name, request_desc, request_value).await;
        }
        if !(self.credential_check)(&ctx) {
            tracing::warn!(
                gateway_control = %name,
                account = %ctx.account,
                method = %ctx.method,
                host = %ctx.host,
                "pva-gateway: control RPC denied — credential check failed"
            );
            return Err(format!(
                "control RPC '{name}' denied: {account}/{method} from {host} \
is not an authorised gateway operator",
                account = ctx.account,
                method = ctx.method,
                host = ctx.host,
            ));
        }
        self.run_control_rpc(&name, &request_value).await
    }

    /// Monitor a diagnostic PV by polling it once per second and forwarding
    /// only changed values. Returns `None` for non-diagnostic names.
    ///
    /// The polling task exits once the returned receiver is dropped.
    async fn subscribe(&self, name: &str) -> Option<mpsc::Receiver<PvField>> {
        if !self.is_diag(name) {
            return None;
        }
        let (tx, rx) = mpsc::channel::<PvField>(4);
        let me = self.clone();
        let pv_name = name.to_string();
        tokio::spawn(async move {
            let mut tick = tokio::time::interval(std::time::Duration::from_secs(1));
            // A tokio interval's first tick completes immediately; consume it so
            // the first poll happens one period after subscription.
            tick.tick().await;
            let mut last: Option<PvField> = None;
            loop {
                tick.tick().await;
                // Stop as soon as the subscriber has gone away. Without this the
                // task would only notice a dropped receiver when a *changed* value
                // failed to send, so a PV whose value never changes would keep the
                // task (and its clone of this source) alive forever.
                if tx.is_closed() {
                    break;
                }
                let Some(value) = me.get_value(&pv_name).await else {
                    continue;
                };
                if last.as_ref() == Some(&value) {
                    continue;
                }
                if tx.send(value.clone()).await.is_err() {
                    break;
                }
                last = Some(value);
            }
        });
        Some(rx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pva_gateway::channel_cache::DEFAULT_CLEANUP_INTERVAL;
    use epics_pva_rs::client::PvaClient;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::path::PathBuf;

    /// Fresh channel cache plus a gateway source wrapping it.
    fn make_source() -> (Arc<ChannelCache>, GatewayChannelSource) {
        let client = Arc::new(PvaClient::builder().build());
        let cache = ChannelCache::new(client, DEFAULT_CLEANUP_INTERVAL);
        let gw = GatewayChannelSource::new(cache.clone());
        (cache, gw)
    }

    /// Channel context for `account`, authenticated via `method`, on localhost.
    fn ctx(account: &str, method: &str) -> ChannelContext {
        ChannelContext {
            peer: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5075),
            account: account.into(),
            method: method.into(),
            host: "localhost".into(),
            authority: String::new(),
            roles: Vec::new(),
            pv_request: None,
        }
    }

    /// Empty request introspection for RPCs that take no arguments.
    fn empty_desc() -> FieldDesc {
        FieldDesc::Structure {
            struct_id: String::new(),
            fields: vec![],
        }
    }

    /// Empty request value for RPCs that take no arguments.
    fn empty_request() -> PvField {
        PvField::Structure(PvStructure::new(""))
    }

    /// NTURI request carrying a single string query argument `arg = value`.
    fn nturi_with_arg(arg: &str, value: &str) -> PvField {
        let mut query = PvStructure::new("");
        query.fields.push((
            arg.into(),
            PvField::Scalar(ScalarValue::String(value.into())),
        ));
        let mut root = PvStructure::new("epics:nt/NTURI:1.0");
        root.fields
            .push(("query".into(), PvField::Structure(query)));
        PvField::Structure(root)
    }

    /// `message` field of a control RPC reply; panics if absent.
    fn reply_message(reply: &PvField) -> String {
        let PvField::Structure(s) = reply else {
            panic!("reply not a structure");
        };
        match s.fields.iter().find(|(n, _)| n == "message") {
            Some((_, PvField::Scalar(ScalarValue::String(m)))) => m.clone(),
            _ => panic!("reply has no message field"),
        }
    }

    /// `value` field of a control RPC reply; panics if absent.
    fn reply_value(reply: &PvField) -> i64 {
        let PvField::Structure(s) = reply else {
            panic!("reply not a structure");
        };
        match s.fields.iter().find(|(n, _)| n == "value") {
            Some((_, PvField::Scalar(ScalarValue::Long(v)))) => *v,
            _ => panic!("reply has no value field"),
        }
    }

    /// Access token for `pv` granted by a fully open access gate.
    async fn checked(pv: &str) -> AccessChecked {
        epics_base_rs::server::access_security::AccessGate::open()
            .check(pv, "localhost", "ops", "ca", "")
            .await
    }

    /// ACF file in the system temp directory that is deleted on drop, so a
    /// failing assertion cannot leave stale files behind.
    struct TempAcf(PathBuf);

    impl TempAcf {
        /// Write `contents` to a per-process file tagged with `tag`.
        fn new(tag: &str, contents: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!("pva_gw_b6_{tag}_{}.acf", std::process::id()));
            std::fs::write(&path, contents).expect("write temp ACF file");
            Self(path)
        }

        /// The file path as UTF-8, as the RPC argument and builder expect.
        fn path_str(&self) -> &str {
            self.0.to_str().expect("temp ACF path is valid UTF-8")
        }
    }

    impl Drop for TempAcf {
        fn drop(&mut self) {
            // Best-effort cleanup; a missing file is not an error here.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[tokio::test]
    async fn control_process_is_rejected() {
        let (cache, gw) = make_source();
        let ctrl = ControlSource::new("gw", cache, gw);
        let err = ctrl
            .process("gw:flush")
            .await
            .expect_err("PROCESS of a control PV must be rejected");
        assert!(err.contains("not PROCESS"), "control PV reason: {err:?}");
        let err = ctrl
            .process("gw:cacheSize")
            .await
            .expect_err("PROCESS of a diagnostic PV must be rejected");
        assert!(
            err.contains("PROCESS not supported"),
            "diag PV reason: {err:?}"
        );
        let err = ctrl
            .process_checked(checked("gw:flush").await, ctx("ops", "ca"))
            .await
            .expect_err("process_checked must reject even with a WRITE token");
        assert!(
            err.contains("not PROCESS"),
            "process_checked reason: {err:?}"
        );
    }

    #[tokio::test]
    async fn control_rpc_denied_without_credential_predicate() {
        let (cache, gw) = make_source();
        let ctrl = ControlSource::new("gw", cache, gw);
        let res = ctrl
            .rpc_checked(
                checked("gw:flush").await,
                empty_desc(),
                empty_request(),
                ctx("ops", "ca"),
            )
            .await;
        let err = res.expect_err("deny-all default must reject control RPC");
        assert!(
            err.contains("not an authorised gateway operator"),
            "denial reason: {err:?}"
        );
    }

    #[tokio::test]
    async fn control_rpc_ctxless_path_always_refused() {
        let (cache, gw) = make_source();
        let ctrl = ControlSource::new("gw", cache, gw).with_credential_check(Arc::new(|_| true));
        let err = ctrl
            .rpc("gw:flush", empty_desc(), empty_request())
            .await
            .expect_err("context-free RPC must be refused");
        assert!(
            err.contains("requires an authenticated request"),
            "refusal reason: {err:?}"
        );
    }

    #[tokio::test]
    async fn flush_rpc_clears_cache_for_authorised_operator() {
        let (cache, gw) = make_source();
        let ctrl = ControlSource::new("gw", cache, gw)
            .with_credential_check(Arc::new(|c| c.account == "ops"));
        let (desc, reply) = ctrl
            .rpc_checked(
                checked("gw:flush").await,
                empty_desc(),
                empty_request(),
                ctx("ops", "ca"),
            )
            .await
            .expect("authorised operator flush must succeed");
        assert!(matches!(desc, FieldDesc::Structure { .. }));
        assert_eq!(reply_value(&reply), 0);
        assert!(reply_message(&reply).contains("flushed"));
    }

    #[tokio::test]
    async fn flush_rpc_denied_for_unlisted_account() {
        let (cache, gw) = make_source();
        let ctrl = ControlSource::new("gw", cache, gw)
            .with_credential_check(Arc::new(|c| c.account == "ops"));
        let res = ctrl
            .rpc_checked(
                checked("gw:flush").await,
                empty_desc(),
                empty_request(),
                ctx("intruder", "ca"),
            )
            .await;
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn drop_rpc_requires_pv_argument() {
        let (cache, gw) = make_source();
        let ctrl = ControlSource::new("gw", cache, gw).with_credential_check(Arc::new(|_| true));
        let err = ctrl
            .rpc_checked(
                checked("gw:drop").await,
                empty_desc(),
                empty_request(),
                ctx("ops", "ca"),
            )
            .await
            .expect_err("drop without a 'pv' argument must fail");
        assert!(err.contains("'pv' argument"), "reason: {err:?}");
    }

    #[tokio::test]
    async fn drop_rpc_reports_missing_entry() {
        let (cache, gw) = make_source();
        let ctrl = ControlSource::new("gw", cache, gw).with_credential_check(Arc::new(|_| true));
        let (_desc, reply) = ctrl
            .rpc_checked(
                checked("gw:drop").await,
                empty_desc(),
                nturi_with_arg("pv", "NO:SUCH:PV"),
                ctx("ops", "ca"),
            )
            .await
            .expect("drop of an absent entry still returns a reply");
        assert_eq!(reply_value(&reply), 0);
        assert!(reply_message(&reply).contains("was not present"));
    }

    #[tokio::test]
    async fn reload_rpc_without_path_or_default_fails() {
        let (cache, gw) = make_source();
        let ctrl = ControlSource::new("gw", cache, gw).with_credential_check(Arc::new(|_| true));
        let err = ctrl
            .rpc_checked(
                checked("gw:reload").await,
                empty_desc(),
                empty_request(),
                ctx("ops", "ca"),
            )
            .await
            .expect_err("reload without any path must fail");
        assert!(err.contains("'path' argument"), "reason: {err:?}");
    }

    #[tokio::test]
    async fn reload_rpc_parses_acf_from_explicit_path() {
        let (cache, gw) = make_source();
        let ctrl = ControlSource::new("gw", cache, gw).with_credential_check(Arc::new(|_| true));
        let acf = TempAcf::new(
            "reload",
            "ASG(DEFAULT) {\n RULE(1, READ)\n RULE(1, WRITE)\n}\n",
        );
        let (_desc, reply) = ctrl
            .rpc_checked(
                checked("gw:reload").await,
                empty_desc(),
                nturi_with_arg("path", acf.path_str()),
                ctx("ops", "ca"),
            )
            .await
            .expect("reload of a valid ACF must succeed");
        assert!(reply_message(&reply).contains("reloaded ACF policy"));
    }

    #[tokio::test]
    async fn reload_rpc_uses_configured_default_path() {
        let (cache, gw) = make_source();
        let acf = TempAcf::new("default", "ASG(DEFAULT) {\n RULE(1, READ)\n}\n");
        let ctrl = ControlSource::new("gw", cache, gw)
            .with_credential_check(Arc::new(|_| true))
            .with_acf_path(acf.path_str());
        let res = ctrl
            .rpc_checked(
                checked("gw:reload").await,
                empty_desc(),
                empty_request(),
                ctx("ops", "ca"),
            )
            .await;
        assert!(res.is_ok(), "reload must use the configured default path");
    }

    #[tokio::test]
    async fn reload_rpc_rejects_unparseable_acf() {
        let (cache, gw) = make_source();
        let ctrl = ControlSource::new("gw", cache, gw).with_credential_check(Arc::new(|_| true));
        let acf = TempAcf::new("bad", "this is not valid ACF (((");
        let res = ctrl
            .rpc_checked(
                checked("gw:reload").await,
                empty_desc(),
                nturi_with_arg("path", acf.path_str()),
                ctx("ops", "ca"),
            )
            .await;
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn diagnostic_pvs_remain_read_only() {
        let (cache, gw) = make_source();
        let ctrl = ControlSource::new("gw", cache, gw);
        assert!(ctrl.has_pv("gw:cacheSize").await);
        assert!(ctrl.has_pv("gw:report").await);
        assert!(!ctrl.is_writable("gw:cacheSize").await);
        assert!(ctrl.get_value("gw:cacheSize").await.is_some());
        let names = ctrl.list_pvs().await;
        assert!(names.contains(&"gw:cacheSize".to_string()));
        assert!(names.contains(&"gw:flush".to_string()));
        assert!(names.contains(&"gw:drop".to_string()));
        assert!(names.contains(&"gw:reload".to_string()));
    }

    #[tokio::test]
    async fn report_pv_summarises_counters() {
        let (cache, gw) = make_source();
        let ctrl = ControlSource::new("gw", cache, gw);
        assert!(matches!(
            ctrl.get_introspection("gw:report").await,
            Some(FieldDesc::Structure { .. })
        ));
        let report = ctrl
            .get_value("gw:report")
            .await
            .expect("report PV must have a value");
        let PvField::Structure(s) = report else {
            panic!("report value not a structure");
        };
        match s.fields.iter().find(|(n, _)| n == "value") {
            Some((_, PvField::Scalar(ScalarValue::String(text)))) => {
                assert!(text.contains("cacheSize="), "report text: {text:?}");
                assert!(text.contains("upstreamCount="), "report text: {text:?}");
                assert!(text.contains("liveSubscribers="), "report text: {text:?}");
            }
            _ => panic!("report has no string value field"),
        }
    }

    #[tokio::test]
    async fn unknown_names_are_not_served() {
        let (cache, gw) = make_source();
        let ctrl = ControlSource::new("gw", cache, gw);
        assert!(!ctrl.has_pv("gw:nope").await);
        assert!(!ctrl.has_pv("other:flush").await);
        assert!(ctrl.get_value("gw:nope").await.is_none());
        assert!(ctrl.get_introspection("gw:nope").await.is_none());
        // Only diagnostic PVs can be monitored; control PVs are RPC-only.
        assert!(ctrl.subscribe("gw:flush").await.is_none());
    }
}