use std::collections::HashMap;
use std::sync::Arc;
use epics_base_rs::server::database::PvDatabase;
use epics_pva_rs::pvdata::{FieldDesc, PvStructure};
use epics_base_rs::types::DbFieldType;
use super::channel::BridgeChannel;
use super::group::GroupChannel;
use super::group_config::GroupPvDef;
use super::pvif::NtType;
use crate::error::{BridgeError, BridgeResult};
/// Policy hook deciding whether a client identity may read or write a channel.
///
/// Both methods default to `true`, so an empty `impl` grants everything (see
/// [`AllowAllAccess`]). The `Send + Sync` bound lets policies be shared as
/// `Arc<dyn AccessControl>` (as `AccessContext` and `BridgeProvider` do).
pub trait AccessControl: Send + Sync {
/// Returns `true` if `user` on `host` may read (get / monitor) `channel`.
fn can_read(&self, _channel: &str, _user: &str, _host: &str) -> bool {
true
}
/// Returns `true` if `user` on `host` may write (put) `channel`.
fn can_write(&self, _channel: &str, _user: &str, _host: &str) -> bool {
true
}
}
/// Policy that permits every read and write (relies on the trait defaults).
pub struct AllowAllAccess;
impl AccessControl for AllowAllAccess {}
/// [`AccessControl`] backed by an EPICS access-security (ACF) configuration.
///
/// The access security group (ASG) of a channel is read from the record's
/// `common.asg` in `db`; records with an empty ASG, and unknown records, are
/// evaluated against `"DEFAULT"`.
pub struct AcfAccessControl {
/// Database consulted on every check to find the record's ASG.
db: Arc<epics_base_rs::server::database::PvDatabase>,
/// Parsed ACF rules evaluated via `check_access_method`.
cfg: Arc<epics_base_rs::server::access_security::AccessSecurityConfig>,
}
impl AcfAccessControl {
pub fn new(
db: Arc<epics_base_rs::server::database::PvDatabase>,
cfg: epics_base_rs::server::access_security::AccessSecurityConfig,
) -> Self {
Self {
db,
cfg: Arc::new(cfg),
}
}
fn level_for(&self, channel: &str, user: &str, host: &str) -> AccessLevelLite {
let asg = self.resolve_asg_blocking(channel);
let level = self
.cfg
.check_access_method(&asg, host, user, 0, "anonymous", "");
match level {
epics_base_rs::server::access_security::AccessLevel::ReadWrite => {
AccessLevelLite::ReadWrite
}
epics_base_rs::server::access_security::AccessLevel::Read => AccessLevelLite::Read,
_ => AccessLevelLite::None,
}
}
fn resolve_asg_blocking(&self, channel: &str) -> String {
let (record_name, _field) = epics_base_rs::server::database::parse_pv_name(channel);
let db = self.db.clone();
let name = record_name.to_string();
let lookup = async move {
if let Some(rec) = db.get_record(&name).await {
let inst = rec.read().await;
if !inst.common.asg.is_empty() {
return inst.common.asg.clone();
}
}
"DEFAULT".to_string()
};
match tokio::runtime::Handle::try_current() {
Ok(handle) => match handle.runtime_flavor() {
tokio::runtime::RuntimeFlavor::MultiThread => {
tokio::task::block_in_place(|| handle.block_on(lookup))
}
_ => "DEFAULT".to_string(),
},
Err(_) => "DEFAULT".to_string(),
}
}
}
/// Reduced access level produced by `AcfAccessControl::level_for`.
#[derive(Copy, Clone, PartialEq, Eq)]
enum AccessLevelLite {
/// Neither read nor write.
None,
/// Read only.
Read,
/// Read and write.
ReadWrite,
}
// Maps the collapsed ACF level onto the read/write questions of the trait.
impl AccessControl for AcfAccessControl {
    /// Readable whenever the ACF grants at least `Read`.
    fn can_read(&self, channel: &str, user: &str, host: &str) -> bool {
        !matches!(self.level_for(channel, user, host), AccessLevelLite::None)
    }

    /// Writable only when the ACF grants `ReadWrite`.
    fn can_write(&self, channel: &str, user: &str, host: &str) -> bool {
        matches!(
            self.level_for(channel, user, host),
            AccessLevelLite::ReadWrite
        )
    }
}
/// A client identity bundled with the access policy that judges it.
///
/// Cloning is cheap: the policy is shared through an `Arc`.
#[derive(Clone)]
pub struct AccessContext {
    /// Policy consulted for every check.
    pub access: Arc<dyn AccessControl>,
    /// Client user name (empty for anonymous contexts).
    pub user: String,
    /// Client host name (empty for anonymous contexts).
    pub host: String,
}

impl AccessContext {
    /// Context with an empty user and host.
    pub fn anonymous(access: Arc<dyn AccessControl>) -> Self {
        Self::with_identity(access, String::new(), String::new())
    }

    /// Context carrying an explicit client identity.
    pub fn with_identity(access: Arc<dyn AccessControl>, user: String, host: String) -> Self {
        AccessContext { access, user, host }
    }

    /// Anonymous context whose policy grants everything.
    pub fn allow_all() -> Self {
        let policy: Arc<dyn AccessControl> = Arc::new(AllowAllAccess);
        Self::anonymous(policy)
    }

    /// Asks the policy whether this identity may read `channel`.
    pub fn can_read(&self, channel: &str) -> bool {
        let Self { access, user, host } = self;
        access.can_read(channel, user, host)
    }

    /// Asks the policy whether this identity may write `channel`.
    pub fn can_write(&self, channel: &str) -> bool {
        let Self { access, user, host } = self;
        access.can_write(channel, user, host)
    }
}

impl Default for AccessContext {
    /// Same as [`AccessContext::allow_all`].
    fn default() -> Self {
        AccessContext::allow_all()
    }
}
/// Source of channels served by the bridge (implemented by [`BridgeProvider`]).
pub trait ChannelProvider: Send + Sync {
/// Provider identifier (`"BRIDGE"` for [`BridgeProvider`]).
fn provider_name(&self) -> &str;
/// Whether a channel called `name` can be served by this provider.
fn channel_find(&self, name: &str) -> impl std::future::Future<Output = bool> + Send;
/// Names of all channels this provider can serve.
fn channel_list(&self) -> impl std::future::Future<Output = Vec<String>> + Send;
/// Opens channel `name` (`BridgeProvider` returns `ChannelNotFound` for unknown names).
fn create_channel(
&self,
name: &str,
) -> impl std::future::Future<Output = BridgeResult<AnyChannel>> + Send;
}
/// Operations on an open channel (a single record or a group PV).
pub trait Channel: Send + Sync {
/// The channel's name.
fn channel_name(&self) -> &str;
/// Reads the channel; `request` is the client-supplied request structure.
fn get(
&self,
request: &PvStructure,
) -> impl std::future::Future<Output = BridgeResult<PvStructure>> + Send;
/// Writes `value` (e.g. an NTScalar-shaped structure with a `value` field).
fn put(
&self,
value: &PvStructure,
) -> impl std::future::Future<Output = BridgeResult<()>> + Send;
/// Type description (introspection) of the channel's structure.
fn get_field(&self) -> impl std::future::Future<Output = BridgeResult<FieldDesc>> + Send;
/// Creates a monitor for subscription updates.
fn create_monitor(
&self,
) -> impl std::future::Future<Output = BridgeResult<super::group::AnyMonitor>> + Send;
}
/// Subscription handle that delivers value updates.
pub trait PvaMonitor: Send + Sync {
/// Returns the next update, if any. NOTE(review): `None` presumably means the
/// monitor has ended or was stopped — TODO confirm against implementations.
fn poll(&mut self) -> impl std::future::Future<Output = Option<PvStructure>> + Send;
/// Starts delivering updates; may fail (e.g. when read access is denied).
fn start(&mut self) -> impl std::future::Future<Output = BridgeResult<()>> + Send;
/// Stops delivering updates.
fn stop(&mut self) -> impl std::future::Future<Output = ()> + Send;
}
/// Channel returned by [`BridgeProvider`]: either a single record or a group PV.
pub enum AnyChannel {
/// Channel backed by one database record.
Single(BridgeChannel),
/// Channel backed by a group definition.
Group(GroupChannel),
}
impl Channel for AnyChannel {
fn channel_name(&self) -> &str {
match self {
Self::Single(ch) => ch.channel_name(),
Self::Group(ch) => ch.channel_name(),
}
}
async fn get(&self, request: &PvStructure) -> BridgeResult<PvStructure> {
match self {
Self::Single(ch) => ch.get(request).await,
Self::Group(ch) => ch.get(request).await,
}
}
async fn put(&self, value: &PvStructure) -> BridgeResult<()> {
match self {
Self::Single(ch) => ch.put(value).await,
Self::Group(ch) => ch.put(value).await,
}
}
async fn get_field(&self) -> BridgeResult<FieldDesc> {
match self {
Self::Single(ch) => ch.get_field().await,
Self::Group(ch) => ch.get_field().await,
}
}
async fn create_monitor(&self) -> BridgeResult<super::group::AnyMonitor> {
match self {
Self::Single(ch) => ch.create_monitor().await,
Self::Group(ch) => ch.create_monitor().await,
}
}
}
/// [`ChannelProvider`] exposing records of a [`PvDatabase`] plus configured group PVs.
pub struct BridgeProvider {
/// Record database backing single-record channels.
db: Arc<PvDatabase>,
/// Group PV definitions keyed by group name.
groups: parking_lot::RwLock<HashMap<String, GroupPvDef>>,
/// Operation counters reported by `op_stats` (updated with relaxed atomics).
channels_created: std::sync::atomic::AtomicU64,
ops_get: std::sync::atomic::AtomicU64,
ops_put: std::sync::atomic::AtomicU64,
ops_subscribe: std::sync::atomic::AtomicU64,
/// Per-record (NT type, value field type) cache used by `create_channel_for`.
record_cache: tokio::sync::RwLock<HashMap<String, (NtType, DbFieldType)>>,
/// Currently installed policy; shared with `LiveAccessProxy` so swaps are seen live.
access_cell: Arc<parking_lot::RwLock<Arc<dyn AccessControl>>>,
}
/// [`AccessControl`] that delegates to whatever policy is currently stored in
/// the provider's shared cell, so `set_access_control` swaps are observed by
/// contexts created before the swap.
struct LiveAccessProxy {
    cell: Arc<parking_lot::RwLock<Arc<dyn AccessControl>>>,
}

impl AccessControl for LiveAccessProxy {
    fn can_read(&self, channel: &str, user: &str, host: &str) -> bool {
        // Snapshot the policy and release the read lock BEFORE evaluating it.
        // A policy may block (the ACF policy waits on a database lookup).
        // Holding the lock meanwhile would stall `set_access_control`. Because
        // parking_lot's RwLock is fair, it would also queue new readers behind
        // that waiting writer.
        let policy = self.cell.read().clone();
        policy.can_read(channel, user, host)
    }

    fn can_write(&self, channel: &str, user: &str, host: &str) -> bool {
        // Same lock-release-before-evaluate rule as `can_read`.
        let policy = self.cell.read().clone();
        policy.can_write(channel, user, host)
    }
}
impl BridgeProvider {
pub fn new(db: Arc<PvDatabase>) -> Self {
Self {
db,
groups: parking_lot::RwLock::new(HashMap::new()),
record_cache: tokio::sync::RwLock::new(HashMap::new()),
access_cell: Arc::new(parking_lot::RwLock::new(Arc::new(AllowAllAccess))),
channels_created: std::sync::atomic::AtomicU64::new(0),
ops_get: std::sync::atomic::AtomicU64::new(0),
ops_put: std::sync::atomic::AtomicU64::new(0),
ops_subscribe: std::sync::atomic::AtomicU64::new(0),
}
}
pub async fn is_writable(&self, name: &str) -> bool {
if self.groups.read().contains_key(name) {
return true;
}
let (record, _field) = epics_base_rs::server::database::parse_pv_name(name);
let Some(rec_arc) = self.db.get_record(record).await else {
return false;
};
let inst = rec_arc.read().await;
!inst.common.disp
}
pub fn op_stats(&self) -> ProviderOpStats {
use std::sync::atomic::Ordering::Relaxed;
ProviderOpStats {
channels_created: self.channels_created.load(Relaxed),
gets: self.ops_get.load(Relaxed),
puts: self.ops_put.load(Relaxed),
subscribes: self.ops_subscribe.load(Relaxed),
}
}
pub fn note_channel_created(&self) {
self.channels_created
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn note_get(&self) {
self.ops_get
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn note_put(&self) {
self.ops_put
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn note_subscribe(&self) {
self.ops_subscribe
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
/// Snapshot of [`BridgeProvider`] operation counters, as returned by `op_stats`.
#[derive(Debug, Clone, Default)]
pub struct ProviderOpStats {
/// `create_channel_for` calls (counted before the name is resolved, so failures count too).
pub channels_created: u64,
/// Calls to `note_get`.
pub gets: u64,
/// Calls to `note_put`.
pub puts: u64,
/// Calls to `note_subscribe`.
pub subscribes: u64,
}
impl BridgeProvider {
    /// Installs `access` as the provider-wide policy.
    ///
    /// Handles from [`BridgeProvider::live_access`] see the new policy on their
    /// next check; `Arc`s previously returned by `access_control` keep the old one.
    pub fn set_access_control(&self, access: Arc<dyn AccessControl>) {
        *self.access_cell.write() = access;
    }

    /// Snapshot of the currently installed policy.
    pub fn access_control(&self) -> Arc<dyn AccessControl> {
        self.access_cell.read().clone()
    }

    /// Policy handle that always delegates to the currently installed policy.
    pub fn live_access(&self) -> Arc<dyn AccessControl> {
        Arc::new(LiveAccessProxy {
            cell: self.access_cell.clone(),
        })
    }

    /// Checks write permission against the current policy.
    ///
    /// The policy is snapshotted first so the cell's lock is not held while it
    /// runs (the ACF policy may block on a database lookup).
    pub fn can_write(&self, channel: &str, user: &str, host: &str) -> bool {
        self.access_control().can_write(channel, user, host)
    }

    /// Checks read permission against the current policy (lock released first,
    /// as in `can_write`).
    pub fn can_read(&self, channel: &str, user: &str, host: &str) -> bool {
        self.access_control().can_read(channel, user, host)
    }

    /// Parses group definitions from JSON; each definition replaces any
    /// existing group with the same name.
    pub fn load_group_config(&self, json: &str) -> BridgeResult<()> {
        let defs = super::group_config::parse_group_config(json)?;
        let mut g = self.groups.write();
        for def in defs {
            g.insert(def.name.clone(), def);
        }
        Ok(())
    }

    /// Reads the file at `path` and loads it with `load_group_config`.
    pub fn load_group_file(&self, path: &str) -> BridgeResult<()> {
        let content = std::fs::read_to_string(path)?;
        self.load_group_config(&content)
    }

    /// Parses group definitions attached to record `record_name` and merges
    /// them into the existing groups via `merge_group_defs` (unlike
    /// `load_group_config`, which replaces whole definitions).
    pub fn load_info_group(&self, record_name: &str, json: &str) -> BridgeResult<()> {
        let defs = super::group_config::parse_info_group(record_name, json)?;
        let mut g = self.groups.write();
        super::group_config::merge_group_defs(&mut g, defs);
        Ok(())
    }

    /// Validates every loaded group and returns how many were examined.
    ///
    /// The only check is that member triggers reference fields of the same
    /// group. Unknown references are logged with `tracing::warn!` and otherwise
    /// ignored.
    pub fn process_groups(&self) -> usize {
        let groups = self.groups.read();
        // Iterate the map in place; cloning every definition is unnecessary.
        for (name, def) in groups.iter() {
            let field_names: std::collections::HashSet<String> =
                def.members.iter().map(|m| m.field_name.clone()).collect();
            for member in &def.members {
                if let super::group_config::TriggerDef::Fields(refs) = &member.triggers {
                    for r in refs {
                        if !field_names.contains(r) {
                            tracing::warn!(
                                group = %name,
                                member = %member.field_name,
                                trigger = %r,
                                "group trigger references unknown field"
                            );
                        }
                    }
                }
            }
        }
        groups.len()
    }

    /// The backing record database.
    pub fn database(&self) -> &Arc<PvDatabase> {
        &self.db
    }

    /// Copy of all group definitions.
    pub fn groups(&self) -> HashMap<String, GroupPvDef> {
        self.groups.read().clone()
    }

    /// Number of loaded groups.
    pub fn group_count(&self) -> usize {
        self.groups.read().len()
    }

    /// Removes all groups and returns how many were removed.
    pub fn reset_groups(&self) -> usize {
        let mut g = self.groups.write();
        let n = g.len();
        g.clear();
        n
    }

    /// Looks up member `field` of `group`, returning its backing channel and mapping.
    pub fn group_member(
        &self,
        group: &str,
        field: &str,
    ) -> Option<(String, super::pvif::FieldMapping)> {
        let g = self.groups.read();
        let def = g.get(group)?;
        let m = def.members.iter().find(|m| m.field_name == field)?;
        Some((m.channel.clone(), m.mapping))
    }

    /// Reads the value behind `group.field`.
    ///
    /// Returns `None` for unknown members, `Structure`/`Const` members, and
    /// database errors. No access check is performed here because no client
    /// identity is available.
    pub async fn get_group_field(
        &self,
        group: &str,
        field: &str,
    ) -> Option<epics_base_rs::types::EpicsValue> {
        let (channel, mapping) = self.group_member(group, field)?;
        if matches!(
            mapping,
            super::pvif::FieldMapping::Structure | super::pvif::FieldMapping::Const
        ) {
            return None;
        }
        self.db.get_pv(&channel).await.ok()
    }

    /// Writes `value` to the channel behind `group.field` on behalf of `user`@`host`.
    ///
    /// # Errors
    /// - `PutRejected` if the group or the member's backing channel is not
    ///   writable for this identity, if the member is `Structure`/`Const`, or
    ///   if the database put fails.
    /// - `RecordNotFound` if the group/member does not exist.
    pub async fn put_group_field(
        &self,
        group: &str,
        field: &str,
        value: epics_base_rs::types::EpicsValue,
        user: &str,
        host: &str,
    ) -> BridgeResult<()> {
        if !self.can_write(group, user, host) {
            return Err(crate::error::BridgeError::PutRejected(format!(
                "write denied for group {group} (user='{user}' host='{host}')"
            )));
        }
        let (channel, mapping) = self
            .group_member(group, field)
            .ok_or_else(|| crate::error::BridgeError::RecordNotFound(format!("{group}.{field}")))?;
        if matches!(
            mapping,
            super::pvif::FieldMapping::Structure | super::pvif::FieldMapping::Const
        ) {
            return Err(crate::error::BridgeError::PutRejected(format!(
                "{group}.{field}: Structure/Const members are not writable"
            )));
        }
        // The group-level check alone is not enough. An ACF policy resolves the
        // ASG from a record, and a group name normally is not a record, so that
        // check falls back to "DEFAULT". Also enforce the member channel's own
        // policy so a group cannot be used to bypass it.
        if !self.can_write(&channel, user, host) {
            return Err(crate::error::BridgeError::PutRejected(format!(
                "write denied for {group}.{field} -> {channel} (user='{user}' host='{host}')"
            )));
        }
        self.db
            .put_pv(&channel, value)
            .await
            .map_err(|e| crate::error::BridgeError::PutRejected(e.to_string()))
    }

    /// Drops all cached record type information.
    pub async fn clear_cache(&self) {
        self.record_cache.write().await.clear();
    }
}
impl ChannelProvider for BridgeProvider {
    /// Fixed provider name `"BRIDGE"`.
    fn provider_name(&self) -> &str {
        "BRIDGE"
    }

    /// Group names are checked first, then database record/alias names.
    async fn channel_find(&self, name: &str) -> bool {
        if self.groups.read().contains_key(name) {
            return true;
        }
        self.db.has_name(name).await
    }

    /// Sorted, de-duplicated list of record names, alias names and group names.
    ///
    /// A group may share its name with a record or alias (the group wins in
    /// `create_channel`), so duplicates are removed rather than listed twice.
    async fn channel_list(&self) -> Vec<String> {
        let mut names = self.db.all_record_names().await;
        names.extend(self.db.all_alias_names().await);
        names.extend(self.groups.read().keys().cloned());
        names.sort_unstable();
        names.dedup();
        names
    }

    /// Opens `name` with an anonymous identity (empty user and host).
    async fn create_channel(&self, name: &str) -> BridgeResult<AnyChannel> {
        self.create_channel_for(name, "", "").await
    }
}
impl BridgeProvider {
    /// Opens channel `name` on behalf of `user`@`host`.
    ///
    /// Group PVs are resolved first, then database records/aliases. The
    /// returned channel carries a live access context, so later calls to
    /// `set_access_control` apply to it.
    ///
    /// Only plain record names (where `parse_pv_name` yields `name` itself) go
    /// through the `(NtType, DbFieldType)` cache. The cache is keyed by record
    /// name, and `BridgeChannel::from_cached` is handed that record name.
    /// Serving a field-qualified name such as `REC.EGU` from the cache would
    /// therefore return a channel for `REC`. Caching `REC.EGU`'s types under
    /// `REC` would also poison later lookups of `REC`.
    ///
    /// # Errors
    /// `ChannelNotFound` if `name` is neither a group nor a database name;
    /// otherwise any error from `BridgeChannel::new`.
    pub async fn create_channel_for(
        &self,
        name: &str,
        user: &str,
        host: &str,
    ) -> BridgeResult<AnyChannel> {
        self.note_channel_created();
        let access_ctx =
            AccessContext::with_identity(self.live_access(), user.to_string(), host.to_string());

        // Bind the clone first so the parking_lot guard is released at the end
        // of this statement, well before any `.await`.
        let group_def = self.groups.read().get(name).cloned();
        if let Some(def) = group_def {
            return Ok(AnyChannel::Group(
                GroupChannel::new(self.db.clone(), def).with_access(access_ctx),
            ));
        }

        let (record_name, _field) = epics_base_rs::server::database::parse_pv_name(name);
        let cacheable = record_name == name;

        if cacheable {
            let cached = self.record_cache.read().await.get(record_name).copied();
            if let Some((nt_type, value_dbf)) = cached {
                return Ok(AnyChannel::Single(
                    BridgeChannel::from_cached(
                        self.db.clone(),
                        record_name.to_string(),
                        nt_type,
                        value_dbf,
                    )
                    .with_access(access_ctx),
                ));
            }
        }

        if !self.db.has_name(name).await {
            return Err(BridgeError::ChannelNotFound(name.to_string()));
        }
        let channel = BridgeChannel::new(self.db.clone(), name).await?;
        if cacheable {
            self.record_cache.write().await.insert(
                record_name.to_string(),
                (channel.nt_type(), channel.value_dbf()),
            );
        }
        Ok(AnyChannel::Single(channel.with_access(access_ctx)))
    }
}
#[cfg(test)]
mod tests {
// Unit tests for the access-control plumbing: policy objects, AccessContext,
// provider-level policy swaps, and enforcement in BridgeChannel/BridgeMonitor.
use super::*;
// Policy allowing reads (trait default) but denying every write.
struct ReadOnly;
impl AccessControl for ReadOnly {
fn can_write(&self, _: &str, _: &str, _: &str) -> bool {
false
}
}
// Policy denying both reads and writes for exactly one channel name.
struct DenySpecific(String);
impl AccessControl for DenySpecific {
fn can_read(&self, channel: &str, _: &str, _: &str) -> bool {
channel != self.0
}
fn can_write(&self, channel: &str, _: &str, _: &str) -> bool {
channel != self.0
}
}
// Multi-thread flavor: AcfAccessControl only resolves a record's ASG through
// block_in_place on a multi-thread runtime (otherwise it uses "DEFAULT").
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn acf_access_control_gates_qsrv_channels() {
use epics_base_rs::server::access_security::parse_acf;
use epics_base_rs::server::database::PvDatabase;
use epics_base_rs::server::records::ai::AiRecord;
// SECURE: any user may read; only UAG `admins` (user "admin") may write.
let acf_text = r#"
UAG(admins) { admin }
ASG(SECURE) {
RULE(1, READ)
RULE(1, WRITE) { UAG(admins) }
}
"#;
let cfg = parse_acf(acf_text).unwrap();
let db = Arc::new(PvDatabase::new());
db.add_record("AI:SEC", Box::new(AiRecord::new(0.0)))
.await
.unwrap();
let rec = db.get_record("AI:SEC").await.unwrap();
rec.write().await.common.asg = "SECURE".to_string();
let acl = AcfAccessControl::new(db.clone(), cfg);
assert!(acl.can_read("AI:SEC", "guest", "anywhere"));
assert!(acl.can_read("AI:SEC", "admin", "anywhere"));
assert!(acl.can_write("AI:SEC", "admin", "anywhere"));
assert!(!acl.can_write("AI:SEC", "guest", "anywhere"));
}
#[test]
fn access_context_allow_all() {
let ctx = AccessContext::allow_all();
assert!(ctx.can_read("ANY"));
assert!(ctx.can_write("ANY"));
}
#[test]
fn access_context_read_only() {
let ctx = AccessContext::anonymous(Arc::new(ReadOnly));
assert!(ctx.can_read("X"));
assert!(!ctx.can_write("X"));
}
#[test]
fn access_context_with_identity() {
let ctx =
AccessContext::with_identity(Arc::new(AllowAllAccess), "alice".into(), "host1".into());
assert_eq!(ctx.user, "alice");
assert_eq!(ctx.host, "host1");
}
#[test]
fn access_context_deny_specific() {
let ctx = AccessContext::anonymous(Arc::new(DenySpecific("SECRET".to_string())));
assert!(ctx.can_read("PUBLIC"));
assert!(!ctx.can_read("SECRET"));
assert!(ctx.can_write("PUBLIC"));
assert!(!ctx.can_write("SECRET"));
}
// A fresh provider allows everything until a policy is installed.
#[test]
fn provider_set_access_control() {
let db = Arc::new(PvDatabase::new());
let provider = BridgeProvider::new(db);
assert!(provider.can_read("X", "u", "h"));
assert!(provider.can_write("X", "u", "h"));
provider.set_access_control(Arc::new(ReadOnly));
assert!(provider.can_read("X", "u", "h"));
assert!(!provider.can_write("X", "u", "h"));
}
// A put through a read-only context must fail with an access-denied error.
#[tokio::test]
async fn read_only_channel_blocks_writes() {
let db = Arc::new(PvDatabase::new());
let access = AccessContext::anonymous(Arc::new(ReadOnly));
let ch = BridgeChannel::from_cached(
db,
"PROT".to_string(),
super::super::pvif::NtType::Scalar,
epics_base_rs::types::DbFieldType::Double,
)
.with_access(access);
let mut put_struct = PvStructure::new("epics:nt/NTScalar:1.0");
put_struct.fields.push((
"value".into(),
epics_pva_rs::pvdata::PvField::Scalar(epics_pva_rs::pvdata::ScalarValue::Double(2.0)),
));
let result = ch.put(&put_struct).await;
assert!(result.is_err(), "expected access denied");
let err = format!("{}", result.unwrap_err());
assert!(
err.contains("denied"),
"expected denial message, got: {err}"
);
}
// Reads of the denied name fail. Reads of other names must not fail with a
// denial; they still error here, presumably because the empty database has
// no such record.
#[tokio::test]
async fn deny_specific_channel_blocks_named() {
let db = Arc::new(PvDatabase::new());
let access = AccessContext::anonymous(Arc::new(DenySpecific("BLOCKED".to_string())));
let ch = BridgeChannel::from_cached(
db.clone(),
"BLOCKED".to_string(),
super::super::pvif::NtType::Scalar,
epics_base_rs::types::DbFieldType::Double,
)
.with_access(access);
let req = PvStructure::new("");
let result = ch.get(&req).await;
assert!(result.is_err(), "expected read denied for BLOCKED");
let ok_access = AccessContext::anonymous(Arc::new(DenySpecific("BLOCKED".to_string())));
let ch2 = BridgeChannel::from_cached(
db,
"ALLOWED".to_string(),
super::super::pvif::NtType::Scalar,
epics_base_rs::types::DbFieldType::Double,
)
.with_access(ok_access);
let result = ch2.get(&req).await;
let err = format!("{:?}", result.unwrap_err());
assert!(
!err.contains("denied"),
"ALLOWED channel should pass access check, got: {err}"
);
}
// Policy denying every read while allowing writes (trait default).
struct WriteOnly;
impl AccessControl for WriteOnly {
fn can_read(&self, _: &str, _: &str, _: &str) -> bool {
false
}
}
// Creating a monitor requires read permission.
#[tokio::test]
async fn create_monitor_blocks_when_read_denied() {
let db = Arc::new(PvDatabase::new());
let access = AccessContext::anonymous(Arc::new(WriteOnly));
let ch = BridgeChannel::from_cached(
db,
"PROT".to_string(),
super::super::pvif::NtType::Scalar,
epics_base_rs::types::DbFieldType::Double,
)
.with_access(access);
let result = ch.create_monitor().await;
match result {
Ok(_) => panic!("expected monitor create denied, got Ok"),
Err(e) => {
let err = format!("{e}");
assert!(
err.contains("monitor create denied"),
"expected monitor denial message, got: {err}"
);
}
}
}
// Contexts built from `live_access()` must see every later policy swap.
#[test]
fn live_access_proxy_observes_policy_swap() {
let db = Arc::new(PvDatabase::new());
let provider = BridgeProvider::new(db);
let ctx =
AccessContext::with_identity(provider.live_access(), "alice".into(), "host1".into());
assert!(ctx.can_read("ANY"));
assert!(ctx.can_write("ANY"));
provider.set_access_control(Arc::new(DenySpecific("SECRET".into())));
assert!(ctx.can_read("ALLOWED"));
assert!(!ctx.can_read("SECRET"), "swap must be observed live");
assert!(!ctx.can_write("SECRET"));
provider.set_access_control(Arc::new(ReadOnly));
assert!(ctx.can_read("X"));
assert!(
!ctx.can_write("X"),
"policy swap must take effect immediately"
);
provider.set_access_control(Arc::new(AllowAllAccess));
assert!(ctx.can_write("X"));
}
// Starting a BridgeMonitor directly also enforces read permission.
#[tokio::test]
async fn bridge_monitor_start_blocks_when_read_denied() {
let db = Arc::new(PvDatabase::new());
let access = AccessContext::anonymous(Arc::new(WriteOnly));
let mut monitor = super::super::monitor::BridgeMonitor::new(
db,
"PROT".to_string(),
super::super::pvif::NtType::Scalar,
)
.with_access(access);
let result = monitor.start().await;
assert!(result.is_err(), "expected monitor start denied");
let err = format!("{}", result.unwrap_err());
assert!(
err.contains("monitor read denied"),
"expected start denial, got: {err}"
);
}
}