use std::collections::HashMap;
use std::sync::Arc;
use epics_base_rs::server::database::PvDatabase;
use epics_pva_rs::pvdata::{FieldDesc, PvStructure};
use epics_base_rs::types::DbFieldType;
use super::channel::BridgeChannel;
use super::group::GroupChannel;
use super::group_config::GroupPvDef;
use super::pvif::NtType;
use crate::error::{BridgeError, BridgeResult};
/// Identity of a connected client, as presented to an [`AccessControl`] policy.
///
/// Every field defaults to empty.
#[derive(Debug, Clone, Default)]
pub struct ClientCreds {
    /// Account name the client presented.
    pub user: String,
    /// Host the client connected from.
    pub host: String,
    /// Authentication method name (e.g. `"ca"`, `"x509"`); may be empty.
    pub method: String,
    /// Name of the authority that vouched for the identity (matched by ACF
    /// `AUTHORITY(...)` rules); may be empty.
    pub authority: String,
    /// Role names held by the client.
    pub roles: Vec<String>,
}

/// Read/write authorization policy consulted before channel operations.
///
/// Every method has a permissive default. The `*_creds` variants forward
/// only `user` and `host` to the legacy methods unless an implementor
/// overrides them to consider the rest of [`ClientCreds`].
pub trait AccessControl: Send + Sync {
    /// Legacy read check keyed on user and host. Default: always allowed.
    fn can_read(&self, _channel: &str, _user: &str, _host: &str) -> bool {
        true
    }

    /// Legacy write check keyed on user and host. Default: always allowed.
    fn can_write(&self, _channel: &str, _user: &str, _host: &str) -> bool {
        true
    }

    /// Full-credential read check; defaults to [`AccessControl::can_read`].
    fn can_read_creds(&self, channel: &str, creds: &ClientCreds) -> bool {
        let ClientCreds { user, host, .. } = creds;
        self.can_read(channel, user, host)
    }

    /// Full-credential write check; defaults to [`AccessControl::can_write`].
    fn can_write_creds(&self, channel: &str, creds: &ClientCreds) -> bool {
        let ClientCreds { user, host, .. } = creds;
        self.can_write(channel, user, host)
    }
}

/// Policy that grants every read and write (all trait defaults).
pub struct AllowAllAccess;

impl AccessControl for AllowAllAccess {}
/// [`AccessControl`] policy that evaluates an EPICS access-security (ACF)
/// configuration against the record behind each channel.
///
/// Every check resolves the record's `ASG`/`ASL` from the database. It then
/// asks the ACF configuration for the best access level that any of the
/// client's credential strings obtains.
pub struct AcfAccessControl {
    db: Arc<epics_base_rs::server::database::PvDatabase>,
    cfg: Arc<epics_base_rs::server::access_security::AccessSecurityConfig>,
}

impl AcfAccessControl {
    /// ASG applied when a record has none, does not exist, or cannot be looked up.
    const FALLBACK_ASG: &'static str = "DEFAULT";

    /// Build a policy that evaluates `cfg` against records stored in `db`.
    pub fn new(
        db: Arc<epics_base_rs::server::database::PvDatabase>,
        cfg: epics_base_rs::server::access_security::AccessSecurityConfig,
    ) -> Self {
        Self {
            db,
            cfg: Arc::new(cfg),
        }
    }

    /// The `(ASG, ASL)` pair used whenever the record cannot be resolved.
    fn fallback_asg_asl() -> (String, u8) {
        (Self::FALLBACK_ASG.to_string(), 0)
    }

    /// Synchronously resolve the `(ASG, ASL)` of the record addressed by `channel`.
    ///
    /// `AccessControl` checks are synchronous but the database is async, so
    /// the lookup is driven to completion here:
    /// - on a multi-thread tokio runtime, via `block_in_place` + `block_on`;
    /// - on a thread with no tokio runtime, on a short-lived private runtime;
    /// - on a current-thread runtime the lookup cannot be driven, so the
    ///   fallback pair is returned. tokio does not allow blocking that runtime
    ///   from inside itself; `block_in_place` panics there.
    ///
    /// A missing record, or one with an empty ASG, also yields the fallback ASG.
    fn resolve_asg_and_asl_blocking(&self, channel: &str) -> (String, u8) {
        let (record_name, _field) = epics_base_rs::server::database::parse_pv_name(channel);
        let db = self.db.clone();
        let name = record_name.to_string();
        let lookup = async move {
            let Some(rec) = db.get_record(&name).await else {
                return Self::fallback_asg_asl();
            };
            let inst = rec.read().await;
            let asg = if inst.common.asg.is_empty() {
                Self::FALLBACK_ASG.to_string()
            } else {
                inst.common.asg.clone()
            };
            let asl = inst.common.asl;
            (asg, asl)
        };
        match tokio::runtime::Handle::try_current() {
            Ok(handle) => match handle.runtime_flavor() {
                tokio::runtime::RuntimeFlavor::MultiThread => {
                    tokio::task::block_in_place(|| handle.block_on(lookup))
                }
                // NOTE(review): the record's real ASG is not consulted here, so
                // records in stricter groups are judged by the fallback ASG's
                // rules. Callers that need ACF enforcement should run on a
                // multi-thread runtime.
                _ => Self::fallback_asg_asl(),
            },
            // No runtime is bound to this thread, so blocking it starves
            // nothing. Resolve the real ASG on a private runtime instead of
            // silently applying the fallback.
            Err(_) => tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .map(|rt| rt.block_on(lookup))
                .unwrap_or_else(|_| Self::fallback_asg_asl()),
        }
    }

    /// Credential strings that ACF `UAG` entries are matched against.
    ///
    /// The first entry is the primary identity:
    /// - for an empty or `"ca"` method, the part of `user` after its last `/`
    ///   (the whole name if there is no `/`);
    /// - otherwise `"<method>/<user>"`.
    ///
    /// One `"role/<role>"` entry follows for each role.
    fn credential_strings(creds: &ClientCreds) -> Vec<String> {
        let mut strings = Vec::with_capacity(1 + creds.roles.len());
        let primary = if creds.method.is_empty() || creds.method == "ca" {
            creds
                .user
                .rsplit_once('/')
                .map_or(creds.user.as_str(), |(_, tail)| tail)
                .to_string()
        } else {
            format!("{}/{}", creds.method, creds.user)
        };
        strings.push(primary);
        strings.extend(creds.roles.iter().map(|role| format!("role/{role}")));
        strings
    }

    /// Best access level granted to `creds` on `channel` across all of its
    /// credential strings. Returns early as soon as any string earns read/write.
    ///
    /// An empty `creds.method` is evaluated as the `"anonymous"` method.
    fn level_for_creds(&self, channel: &str, creds: &ClientCreds) -> AccessLevelLite {
        use epics_base_rs::server::access_security::AccessLevel;
        let (asg, asl) = self.resolve_asg_and_asl_blocking(channel);
        let method = if creds.method.is_empty() {
            "anonymous"
        } else {
            creds.method.as_str()
        };
        let mut best = AccessLevelLite::None;
        for cred_user in Self::credential_strings(creds) {
            let level = match self.cfg.check_access_method(
                &asg,
                &creds.host,
                &cred_user,
                asl,
                method,
                &creds.authority,
            ) {
                AccessLevel::ReadWrite => AccessLevelLite::ReadWrite,
                AccessLevel::Read => AccessLevelLite::Read,
                _ => AccessLevelLite::None,
            };
            if level == AccessLevelLite::ReadWrite {
                return level;
            }
            best = best.max(level);
        }
        best
    }
}

/// Effective access granted to a client, declared from least to most
/// permissive so the derived ordering can pick the stronger of two levels.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum AccessLevelLite {
    None,
    Read,
    ReadWrite,
}
impl AccessControl for AcfAccessControl {
fn can_read(&self, channel: &str, user: &str, host: &str) -> bool {
self.can_read_creds(
channel,
&ClientCreds {
user: user.to_string(),
host: host.to_string(),
..Default::default()
},
)
}
fn can_write(&self, channel: &str, user: &str, host: &str) -> bool {
self.can_write_creds(
channel,
&ClientCreds {
user: user.to_string(),
host: host.to_string(),
..Default::default()
},
)
}
fn can_read_creds(&self, channel: &str, creds: &ClientCreds) -> bool {
self.level_for_creds(channel, creds) != AccessLevelLite::None
}
fn can_write_creds(&self, channel: &str, creds: &ClientCreds) -> bool {
self.level_for_creds(channel, creds) == AccessLevelLite::ReadWrite
}
}
#[derive(Clone)]
pub struct AccessContext {
pub access: Arc<dyn AccessControl>,
pub user: String,
pub host: String,
pub method: String,
pub authority: String,
pub roles: Vec<String>,
}
impl AccessContext {
pub fn anonymous(access: Arc<dyn AccessControl>) -> Self {
Self {
access,
user: String::new(),
host: String::new(),
method: String::new(),
authority: String::new(),
roles: Vec::new(),
}
}
pub fn with_identity(access: Arc<dyn AccessControl>, user: String, host: String) -> Self {
Self {
access,
user,
host,
method: String::new(),
authority: String::new(),
roles: Vec::new(),
}
}
pub fn with_creds(access: Arc<dyn AccessControl>, creds: ClientCreds) -> Self {
Self {
access,
user: creds.user,
host: creds.host,
method: creds.method,
authority: creds.authority,
roles: creds.roles,
}
}
pub fn allow_all() -> Self {
Self::anonymous(Arc::new(AllowAllAccess))
}
pub fn can_read(&self, channel: &str) -> bool {
self.access.can_read_creds(channel, &self.to_client_creds())
}
pub fn can_write(&self, channel: &str) -> bool {
self.access
.can_write_creds(channel, &self.to_client_creds())
}
fn to_client_creds(&self) -> ClientCreds {
ClientCreds {
user: self.user.clone(),
host: self.host.clone(),
method: self.method.clone(),
authority: self.authority.clone(),
roles: self.roles.clone(),
}
}
}
impl Default for AccessContext {
fn default() -> Self {
Self::allow_all()
}
}
/// A source of named channels that can be searched, listed and opened.
pub trait ChannelProvider: Send + Sync {
/// Short identifier of this provider.
fn provider_name(&self) -> &str;
/// Whether this provider can serve a channel called `name`.
fn channel_find(&self, name: &str) -> impl std::future::Future<Output = bool> + Send;
/// Names of the channels this provider can serve.
fn channel_list(&self) -> impl std::future::Future<Output = Vec<String>> + Send;
/// Open a channel for `name`.
fn create_channel(
&self,
name: &str,
) -> impl std::future::Future<Output = BridgeResult<AnyChannel>> + Send;
}
/// An open channel supporting get, put, type introspection and monitors.
pub trait Channel: Send + Sync {
/// Name the channel was opened with.
fn channel_name(&self) -> &str;
/// Read the channel's current value; `request` is the client's pvRequest structure.
fn get(
&self,
request: &PvStructure,
) -> impl std::future::Future<Output = BridgeResult<PvStructure>> + Send;
/// Write `value` to the channel.
fn put(
&self,
value: &PvStructure,
) -> impl std::future::Future<Output = BridgeResult<()>> + Send;
/// Type description of the channel's data.
fn get_field(&self) -> impl std::future::Future<Output = BridgeResult<FieldDesc>> + Send;
/// Create a monitor (subscription) on this channel.
fn create_monitor(
&self,
) -> impl std::future::Future<Output = BridgeResult<super::group::AnyMonitor>> + Send;
}
/// A subscription producing a stream of value updates.
pub trait PvaMonitor: Send + Sync {
/// Next update, if any. NOTE(review): the meaning of `None` (no update vs.
/// monitor closed) is defined by implementors — confirm before relying on it.
fn poll(&mut self) -> impl std::future::Future<Output = Option<PvStructure>> + Send;
/// Begin delivering updates.
fn start(&mut self) -> impl std::future::Future<Output = BridgeResult<()>> + Send;
/// Stop delivering updates.
fn stop(&mut self) -> impl std::future::Future<Output = ()> + Send;
}
/// A channel returned by [`BridgeProvider`]: either backed by a single
/// database record or by a group PV definition.
pub enum AnyChannel {
/// Channel onto a single record/field.
Single(BridgeChannel),
/// Channel onto a group PV combining several members.
Group(GroupChannel),
}
impl Channel for AnyChannel {
fn channel_name(&self) -> &str {
match self {
Self::Single(ch) => ch.channel_name(),
Self::Group(ch) => ch.channel_name(),
}
}
async fn get(&self, request: &PvStructure) -> BridgeResult<PvStructure> {
match self {
Self::Single(ch) => ch.get(request).await,
Self::Group(ch) => ch.get(request).await,
}
}
async fn put(&self, value: &PvStructure) -> BridgeResult<()> {
match self {
Self::Single(ch) => ch.put(value).await,
Self::Group(ch) => ch.put(value).await,
}
}
async fn get_field(&self) -> BridgeResult<FieldDesc> {
match self {
Self::Single(ch) => ch.get_field().await,
Self::Group(ch) => ch.get_field().await,
}
}
async fn create_monitor(&self) -> BridgeResult<super::group::AnyMonitor> {
match self {
Self::Single(ch) => ch.create_monitor().await,
Self::Group(ch) => ch.create_monitor().await,
}
}
}
/// Channel provider exposing database records and group PVs, with a
/// swappable access-control policy and simple operation counters.
pub struct BridgeProvider {
// Database that record-backed channels read and write.
db: Arc<PvDatabase>,
// Group PV definitions keyed by group name.
groups: parking_lot::RwLock<HashMap<String, GroupPvDef>>,
// Counters reported by `op_stats`, bumped by the `note_*` methods.
channels_created: std::sync::atomic::AtomicU64,
ops_get: std::sync::atomic::AtomicU64,
ops_put: std::sync::atomic::AtomicU64,
ops_subscribe: std::sync::atomic::AtomicU64,
// (NtType, value DBF) per channel name, filled by `create_channel_with_creds`
// for names without a JSON filter suffix; emptied by `clear_cache`.
record_cache: tokio::sync::RwLock<HashMap<String, (NtType, DbFieldType)>>,
// Active access policy; shared with `LiveAccessProxy` handles so swaps are
// observed by already-created channels.
access_cell: Arc<parking_lot::RwLock<Arc<dyn AccessControl>>>,
}
/// [`AccessControl`] handle that forwards every check to the policy currently
/// installed in a provider's shared access cell. Policy swaps are therefore
/// observed by already-open channels.
struct LiveAccessProxy {
    cell: Arc<parking_lot::RwLock<Arc<dyn AccessControl>>>,
}

impl LiveAccessProxy {
    /// Clone the active policy out of the cell, releasing the lock immediately.
    ///
    /// Delegated checks may block (`AcfAccessControl` waits on a database
    /// lookup), so they must not run under the cell lock. parking_lot's
    /// `RwLock` is task-fair: once `set_access_control` queues for the write
    /// lock, new readers wait too. A single slow check held under the lock
    /// would then stall every access check in the provider.
    fn current(&self) -> Arc<dyn AccessControl> {
        self.cell.read().clone()
    }
}

impl AccessControl for LiveAccessProxy {
    fn can_read(&self, channel: &str, user: &str, host: &str) -> bool {
        self.current().can_read(channel, user, host)
    }

    fn can_write(&self, channel: &str, user: &str, host: &str) -> bool {
        self.current().can_write(channel, user, host)
    }

    fn can_read_creds(&self, channel: &str, creds: &ClientCreds) -> bool {
        self.current().can_read_creds(channel, creds)
    }

    fn can_write_creds(&self, channel: &str, creds: &ClientCreds) -> bool {
        self.current().can_write_creds(channel, creds)
    }
}
impl BridgeProvider {
pub fn new(db: Arc<PvDatabase>) -> Self {
Self {
db,
groups: parking_lot::RwLock::new(HashMap::new()),
record_cache: tokio::sync::RwLock::new(HashMap::new()),
access_cell: Arc::new(parking_lot::RwLock::new(Arc::new(AllowAllAccess))),
channels_created: std::sync::atomic::AtomicU64::new(0),
ops_get: std::sync::atomic::AtomicU64::new(0),
ops_put: std::sync::atomic::AtomicU64::new(0),
ops_subscribe: std::sync::atomic::AtomicU64::new(0),
}
}
pub async fn is_writable(&self, name: &str) -> bool {
if self.groups.read().contains_key(name) {
return true;
}
let (record, _field) = epics_base_rs::server::database::parse_pv_name(name);
let Some(rec_arc) = self.db.get_record(record).await else {
return false;
};
let inst = rec_arc.read().await;
!inst.common.disp
}
pub fn op_stats(&self) -> ProviderOpStats {
use std::sync::atomic::Ordering::Relaxed;
ProviderOpStats {
channels_created: self.channels_created.load(Relaxed),
gets: self.ops_get.load(Relaxed),
puts: self.ops_put.load(Relaxed),
subscribes: self.ops_subscribe.load(Relaxed),
}
}
pub fn note_channel_created(&self) {
self.channels_created
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn note_get(&self) {
self.ops_get
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn note_put(&self) {
self.ops_put
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn note_subscribe(&self) {
self.ops_subscribe
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
/// Point-in-time copy of a [`BridgeProvider`]'s operation counters, as
/// returned by `BridgeProvider::op_stats`.
#[derive(Debug, Clone, Default)]
pub struct ProviderOpStats {
// Number of `note_channel_created` calls.
pub channels_created: u64,
// Number of `note_get` calls.
pub gets: u64,
// Number of `note_put` calls.
pub puts: u64,
// Number of `note_subscribe` calls.
pub subscribes: u64,
}
impl BridgeProvider {
pub fn set_access_control(&self, access: Arc<dyn AccessControl>) {
*self.access_cell.write() = access;
}
pub fn access_control(&self) -> Arc<dyn AccessControl> {
self.access_cell.read().clone()
}
pub fn live_access(&self) -> Arc<dyn AccessControl> {
Arc::new(LiveAccessProxy {
cell: self.access_cell.clone(),
})
}
pub fn can_write(&self, channel: &str, user: &str, host: &str) -> bool {
self.access_cell.read().can_write(channel, user, host)
}
pub fn can_read(&self, channel: &str, user: &str, host: &str) -> bool {
self.access_cell.read().can_read(channel, user, host)
}
pub fn load_group_config(&self, json: &str) -> BridgeResult<()> {
let defs = super::group_config::parse_group_config(json)?;
let mut g = self.groups.write();
for def in defs {
g.insert(def.name.clone(), def);
}
Ok(())
}
pub fn load_group_file(&self, path: &str) -> BridgeResult<()> {
let content = std::fs::read_to_string(path)?;
self.load_group_config(&content)
}
pub fn load_info_group(&self, record_name: &str, json: &str) -> BridgeResult<()> {
let defs = super::group_config::parse_info_group(record_name, json)?;
let mut g = self.groups.write();
super::group_config::merge_group_defs(&mut g, defs);
Ok(())
}
pub fn process_groups(&self) -> usize {
let g = self.groups.read();
let names: Vec<String> = g.keys().cloned().collect();
let mut finalized = 0;
for name in names {
let def = g.get(&name).cloned().unwrap();
let field_names: std::collections::HashSet<String> =
def.members.iter().map(|m| m.field_name.clone()).collect();
for member in &def.members {
if let super::group_config::TriggerDef::Fields(refs) = &member.triggers {
for r in refs {
if !field_names.contains(r) {
tracing::warn!(
group = %name,
member = %member.field_name,
trigger = %r,
"group trigger references unknown field"
);
}
}
}
}
finalized += 1;
}
finalized
}
pub fn database(&self) -> &Arc<PvDatabase> {
&self.db
}
pub fn groups(&self) -> HashMap<String, GroupPvDef> {
self.groups.read().clone()
}
pub fn group_count(&self) -> usize {
self.groups.read().len()
}
pub fn has_group_pv(&self, name: &str) -> bool {
self.groups.read().contains_key(name)
}
pub fn group_is_pure_self_trigger(&self, name: &str) -> bool {
self.groups
.read()
.get(name)
.map(|g| g.is_pure_self_trigger())
.unwrap_or(false)
}
pub async fn hosts_pv(&self, name: &str) -> bool {
self.channel_find(name).await
}
pub fn reset_groups(&self) -> usize {
let mut g = self.groups.write();
let n = g.len();
g.clear();
n
}
pub fn group_member(
&self,
group: &str,
field: &str,
) -> Option<(String, super::pvif::FieldMapping)> {
let g = self.groups.read();
let def = g.get(group)?;
let m = def.members.iter().find(|m| m.field_name == field)?;
Some((m.channel.clone(), m.mapping))
}
pub async fn get_group_field(
&self,
group: &str,
field: &str,
) -> Option<epics_base_rs::types::EpicsValue> {
let (channel, mapping) = self.group_member(group, field)?;
if matches!(
mapping,
super::pvif::FieldMapping::Structure | super::pvif::FieldMapping::Const
) {
return None;
}
self.db.get_pv(&channel).await.ok()
}
pub async fn put_group_field(
&self,
group: &str,
field: &str,
value: epics_base_rs::types::EpicsValue,
user: &str,
host: &str,
) -> BridgeResult<()> {
if !self.can_write(group, user, host) {
return Err(crate::error::BridgeError::PutRejected(format!(
"write denied for group {group} (user='{user}' host='{host}')"
)));
}
let (channel, mapping) = self
.group_member(group, field)
.ok_or_else(|| crate::error::BridgeError::RecordNotFound(format!("{group}.{field}")))?;
if matches!(
mapping,
super::pvif::FieldMapping::Structure | super::pvif::FieldMapping::Const
) {
return Err(crate::error::BridgeError::PutRejected(format!(
"{group}.{field}: Structure/Const members are not writable"
)));
}
self.db
.put_pv(&channel, value)
.await
.map_err(|e| crate::error::BridgeError::PutRejected(e.to_string()))
}
pub async fn clear_cache(&self) {
self.record_cache.write().await.clear();
}
}
impl ChannelProvider for BridgeProvider {
    fn provider_name(&self) -> &str {
        "BRIDGE"
    }

    /// A name is served if it is a configured group PV or the database knows it.
    async fn channel_find(&self, name: &str) -> bool {
        if self.groups.read().contains_key(name) {
            return true;
        }
        self.db.has_name(name).await
    }

    /// All record, alias and group names, sorted and without duplicates.
    async fn channel_list(&self) -> Vec<String> {
        let mut names = self.db.all_record_names().await;
        names.extend(self.db.all_alias_names().await);
        names.extend(self.groups.read().keys().cloned());
        // A group may share its name with a record or alias; list each name once.
        names.sort_unstable();
        names.dedup();
        names
    }

    /// Open `name` for a client with no identity (empty user and host).
    async fn create_channel(&self, name: &str) -> BridgeResult<AnyChannel> {
        self.create_channel_for(name, "", "").await
    }
}
impl BridgeProvider {
pub async fn create_channel_for(
&self,
name: &str,
user: &str,
host: &str,
) -> BridgeResult<AnyChannel> {
self.create_channel_with_creds(
name,
ClientCreds {
user: user.to_string(),
host: host.to_string(),
..Default::default()
},
)
.await
}
pub async fn create_channel_with_creds(
&self,
name: &str,
creds: ClientCreds,
) -> BridgeResult<AnyChannel> {
self.note_channel_created();
let access_ctx = AccessContext::with_creds(self.live_access(), creds);
if let Some(def) = self.groups.read().get(name).cloned() {
return Ok(AnyChannel::Group(
GroupChannel::new(self.db.clone(), def).with_access(access_ctx),
));
}
let parsed = epics_base_rs::server::database::filters::split_channel_name(name);
let resolution_name = parsed.record_path.as_str();
let (record_name, field) = epics_base_rs::server::database::parse_pv_name(resolution_name);
let field_upper = field.to_ascii_uppercase();
if parsed.json_suffix.is_none() {
let cache = self.record_cache.read().await;
if let Some(&(nt_type, value_dbf)) = cache.get(name) {
return Ok(AnyChannel::Single(
BridgeChannel::from_cached(
self.db.clone(),
name.to_string(),
record_name.to_string(),
field_upper,
nt_type,
value_dbf,
)
.with_access(access_ctx),
));
}
}
if self.db.has_name(resolution_name).await {
let channel = BridgeChannel::new(self.db.clone(), name).await?;
if parsed.json_suffix.is_none() {
let mut cache = self.record_cache.write().await;
cache.insert(name.to_string(), (channel.nt_type(), channel.value_dbf()));
}
return Ok(AnyChannel::Single(channel.with_access(access_ctx)));
}
Err(BridgeError::ChannelNotFound(name.to_string()))
}
}
// Unit tests for the access-control policies, AccessContext, the provider's
// policy swapping, and access gating on channels and monitors.
#[cfg(test)]
mod tests {
use super::*;
// Policy that allows every read (trait default) and denies every write.
struct ReadOnly;
impl AccessControl for ReadOnly {
fn can_write(&self, _: &str, _: &str, _: &str) -> bool {
false
}
}
// Policy that denies both reads and writes on the one channel name it holds.
struct DenySpecific(String);
impl AccessControl for DenySpecific {
fn can_read(&self, channel: &str, _: &str, _: &str) -> bool {
channel != self.0
}
fn can_write(&self, channel: &str, _: &str, _: &str) -> bool {
channel != self.0
}
}
// A record in ASG SECURE is readable by anyone but writable only by UAG(admins).
// Multi-thread flavor: AcfAccessControl only performs its blocking ASG lookup
// inside a runtime when that runtime is multi-threaded.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn acf_access_control_gates_qsrv_channels() {
use epics_base_rs::server::access_security::parse_acf;
use epics_base_rs::server::database::PvDatabase;
use epics_base_rs::server::records::ai::AiRecord;
let acf_text = r#"
UAG(admins) { admin }
ASG(SECURE) {
RULE(1, READ)
RULE(1, WRITE) { UAG(admins) }
}
"#;
let cfg = parse_acf(acf_text).unwrap();
let db = Arc::new(PvDatabase::new());
db.add_record("AI:SEC", Box::new(AiRecord::new(0.0)))
.await
.unwrap();
let rec = db.get_record("AI:SEC").await.unwrap();
rec.write().await.common.asg = "SECURE".to_string();
let acl = AcfAccessControl::new(db.clone(), cfg);
assert!(acl.can_read("AI:SEC", "guest", "anywhere"));
assert!(acl.can_read("AI:SEC", "admin", "anywhere"));
assert!(acl.can_write("AI:SEC", "admin", "anywhere"));
assert!(!acl.can_write("AI:SEC", "guest", "anywhere"));
}
// ACF features reachable through ClientCreds: ASL gating, METHOD, AUTHORITY,
// and role-based UAG entries ("role/<name>").
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn br_r4_acf_method_authority_roles_field_asl() {
use epics_base_rs::server::access_security::parse_acf;
use epics_base_rs::server::database::PvDatabase;
use epics_base_rs::server::records::ai::AiRecord;
let acf_text = r#"
UAG(writers) { alice }
UAG(ops_role) { "role/ops" }
ASG(ASL_GATED) {
RULE(1, READ)
RULE(0, WRITE) { UAG(writers) }
}
ASG(METHOD_GATED) {
RULE(1, READ)
RULE(1, WRITE) { METHOD("x509") }
}
ASG(AUTHORITY_GATED) {
RULE(1, READ)
RULE(1, WRITE) { AUTHORITY("Trusted Root") }
}
ASG(ROLE_GATED) {
RULE(1, READ)
RULE(1, WRITE) { UAG(ops_role) }
}
"#;
let cfg = parse_acf(acf_text).unwrap();
let db = Arc::new(PvDatabase::new());
for name in &["AI:ASL0", "AI:ASL1", "AI:METH", "AI:AUTH", "AI:ROLE"] {
db.add_record(name, Box::new(AiRecord::new(0.0)))
.await
.unwrap();
}
{
let rec = db.get_record("AI:ASL0").await.unwrap();
let mut w = rec.write().await;
w.common.asg = "ASL_GATED".to_string();
w.common.asl = 0;
}
{
let rec = db.get_record("AI:ASL1").await.unwrap();
let mut w = rec.write().await;
w.common.asg = "ASL_GATED".to_string();
w.common.asl = 1;
}
{
let rec = db.get_record("AI:METH").await.unwrap();
rec.write().await.common.asg = "METHOD_GATED".to_string();
}
{
let rec = db.get_record("AI:AUTH").await.unwrap();
rec.write().await.common.asg = "AUTHORITY_GATED".to_string();
}
{
let rec = db.get_record("AI:ROLE").await.unwrap();
rec.write().await.common.asg = "ROLE_GATED".to_string();
}
let acl = AcfAccessControl::new(db.clone(), cfg);
assert!(
acl.can_write("AI:ASL0", "alice", "h"),
"ASL=0: alice should be allowed to write"
);
assert!(
!acl.can_write("AI:ASL1", "alice", "h"),
"ASL=1: RULE(0,WRITE) must be skipped → write denied"
);
let x509_creds = ClientCreds {
user: "alice".to_string(),
host: "h".to_string(),
method: "x509".to_string(),
authority: String::new(),
roles: Vec::new(),
};
assert!(
acl.can_write_creds("AI:METH", &x509_creds),
"x509 client must match METHOD(\"x509\") rule"
);
let ca_creds = ClientCreds {
user: "alice".to_string(),
host: "h".to_string(),
method: "ca".to_string(),
authority: String::new(),
roles: Vec::new(),
};
assert!(
!acl.can_write_creds("AI:METH", &ca_creds),
"ca client must NOT match METHOD(\"x509\")-only rule"
);
let trusted_creds = ClientCreds {
user: "alice".to_string(),
host: "h".to_string(),
method: "x509".to_string(),
authority: "Trusted Root".to_string(),
roles: Vec::new(),
};
assert!(
acl.can_write_creds("AI:AUTH", &trusted_creds),
"correct authority must match AUTHORITY(\"Trusted Root\")"
);
let other_ca_creds = ClientCreds {
user: "alice".to_string(),
host: "h".to_string(),
method: "x509".to_string(),
authority: "Other CA".to_string(),
roles: Vec::new(),
};
assert!(
!acl.can_write_creds("AI:AUTH", &other_ca_creds),
"wrong authority must NOT match"
);
let ops_creds = ClientCreds {
user: "bob".to_string(),
host: "h".to_string(),
method: "ca".to_string(),
authority: String::new(),
roles: vec!["ops".to_string()],
};
assert!(
acl.can_write_creds("AI:ROLE", &ops_creds),
"client with role 'ops' must match UAG entry 'role/ops'"
);
let no_role_creds = ClientCreds {
user: "bob".to_string(),
host: "h".to_string(),
method: "ca".to_string(),
authority: String::new(),
roles: Vec::new(),
};
assert!(
!acl.can_write_creds("AI:ROLE", &no_role_creds),
"client without 'ops' role must NOT write to ROLE_GATED"
);
}
// The default allow-all context grants reads and writes on any name.
#[test]
fn access_context_allow_all() {
let ctx = AccessContext::allow_all();
assert!(ctx.can_read("ANY"));
assert!(ctx.can_write("ANY"));
}
// An anonymous context under ReadOnly can read but not write.
#[test]
fn access_context_read_only() {
let ctx = AccessContext::anonymous(Arc::new(ReadOnly));
assert!(ctx.can_read("X"));
assert!(!ctx.can_write("X"));
}
// with_identity stores the given user and host on the context.
#[test]
fn access_context_with_identity() {
let ctx =
AccessContext::with_identity(Arc::new(AllowAllAccess), "alice".into(), "host1".into());
assert_eq!(ctx.user, "alice");
assert_eq!(ctx.host, "host1");
}
// Only the denied channel name is blocked, for both read and write.
#[test]
fn access_context_deny_specific() {
let ctx = AccessContext::anonymous(Arc::new(DenySpecific("SECRET".to_string())));
assert!(ctx.can_read("PUBLIC"));
assert!(!ctx.can_read("SECRET"));
assert!(ctx.can_write("PUBLIC"));
assert!(!ctx.can_write("SECRET"));
}
// A new provider allows everything; installing ReadOnly removes write access.
#[test]
fn provider_set_access_control() {
let db = Arc::new(PvDatabase::new());
let provider = BridgeProvider::new(db);
assert!(provider.can_read("X", "u", "h"));
assert!(provider.can_write("X", "u", "h"));
provider.set_access_control(Arc::new(ReadOnly));
assert!(provider.can_read("X", "u", "h"));
assert!(!provider.can_write("X", "u", "h"));
}
// A put on a channel whose context forbids writes fails with a "denied" error.
#[tokio::test]
async fn read_only_channel_blocks_writes() {
let db = Arc::new(PvDatabase::new());
let access = AccessContext::anonymous(Arc::new(ReadOnly));
let ch = BridgeChannel::from_cached(
db,
"PROT".to_string(),
"PROT".to_string(),
"VAL".to_string(),
super::super::pvif::NtType::Scalar,
epics_base_rs::types::DbFieldType::Double,
)
.with_access(access);
let mut put_struct = PvStructure::new("epics:nt/NTScalar:1.0");
put_struct.fields.push((
"value".into(),
epics_pva_rs::pvdata::PvField::Scalar(epics_pva_rs::pvdata::ScalarValue::Double(2.0)),
));
let result = ch.put(&put_struct).await;
assert!(result.is_err(), "expected access denied");
let err = format!("{}", result.unwrap_err());
assert!(
err.contains("denied"),
"expected denial message, got: {err}"
);
}
// A get on the denied name fails. The allowed name still errors (the database
// is empty), but the error must not be an access denial.
#[tokio::test]
async fn deny_specific_channel_blocks_named() {
let db = Arc::new(PvDatabase::new());
let access = AccessContext::anonymous(Arc::new(DenySpecific("BLOCKED".to_string())));
let ch = BridgeChannel::from_cached(
db.clone(),
"BLOCKED".to_string(),
"BLOCKED".to_string(),
"VAL".to_string(),
super::super::pvif::NtType::Scalar,
epics_base_rs::types::DbFieldType::Double,
)
.with_access(access);
let req = PvStructure::new("");
let result = ch.get(&req).await;
assert!(result.is_err(), "expected read denied for BLOCKED");
let ok_access = AccessContext::anonymous(Arc::new(DenySpecific("BLOCKED".to_string())));
let ch2 = BridgeChannel::from_cached(
db,
"ALLOWED".to_string(),
"ALLOWED".to_string(),
"VAL".to_string(),
super::super::pvif::NtType::Scalar,
epics_base_rs::types::DbFieldType::Double,
)
.with_access(ok_access);
let result = ch2.get(&req).await;
let err = format!("{:?}", result.unwrap_err());
assert!(
!err.contains("denied"),
"ALLOWED channel should pass access check, got: {err}"
);
}
// Policy that denies every read (writes fall back to the allow-all default).
struct WriteOnly;
impl AccessControl for WriteOnly {
fn can_read(&self, _: &str, _: &str, _: &str) -> bool {
false
}
}
// Creating a monitor requires read access.
#[tokio::test]
async fn create_monitor_blocks_when_read_denied() {
let db = Arc::new(PvDatabase::new());
let access = AccessContext::anonymous(Arc::new(WriteOnly));
let ch = BridgeChannel::from_cached(
db,
"PROT".to_string(),
"PROT".to_string(),
"VAL".to_string(),
super::super::pvif::NtType::Scalar,
epics_base_rs::types::DbFieldType::Double,
)
.with_access(access);
let result = ch.create_monitor().await;
match result {
Ok(_) => panic!("expected monitor create denied, got Ok"),
Err(e) => {
let err = format!("{e}");
assert!(
err.contains("monitor create denied"),
"expected monitor denial message, got: {err}"
);
}
}
}
// A context built on live_access() sees each set_access_control swap immediately.
#[test]
fn live_access_proxy_observes_policy_swap() {
let db = Arc::new(PvDatabase::new());
let provider = BridgeProvider::new(db);
let ctx =
AccessContext::with_identity(provider.live_access(), "alice".into(), "host1".into());
assert!(ctx.can_read("ANY"));
assert!(ctx.can_write("ANY"));
provider.set_access_control(Arc::new(DenySpecific("SECRET".into())));
assert!(ctx.can_read("ALLOWED"));
assert!(!ctx.can_read("SECRET"), "swap must be observed live");
assert!(!ctx.can_write("SECRET"));
provider.set_access_control(Arc::new(ReadOnly));
assert!(ctx.can_read("X"));
assert!(
!ctx.can_write("X"),
"policy swap must take effect immediately"
);
provider.set_access_control(Arc::new(AllowAllAccess));
assert!(ctx.can_write("X"));
}
// Starting a BridgeMonitor also requires read access.
#[tokio::test]
async fn bridge_monitor_start_blocks_when_read_denied() {
let db = Arc::new(PvDatabase::new());
let access = AccessContext::anonymous(Arc::new(WriteOnly));
let mut monitor = super::super::monitor::BridgeMonitor::new(
db,
"PROT".to_string(),
"VAL".to_string(),
super::super::pvif::NtType::Scalar,
)
.with_access(access);
let result = monitor.start().await;
assert!(result.is_err(), "expected monitor start denied");
let err = format!("{}", result.unwrap_err());
assert!(
err.contains("monitor read denied"),
"expected start denial, got: {err}"
);
}
// The legacy user/host API and empty-method creds are evaluated as METHOD
// "anonymous"; an explicit "ca" method is not.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn mr_r12_legacy_path_matches_method_anonymous_rule() {
use epics_base_rs::server::access_security::parse_acf;
use epics_base_rs::server::database::PvDatabase;
use epics_base_rs::server::records::ai::AiRecord;
let acf_text = r#"
ASG(ANON_GATED) {
RULE(1, READ) { METHOD("anonymous") }
}
"#;
let cfg = parse_acf(acf_text).unwrap();
let db = Arc::new(PvDatabase::new());
db.add_record("AI:ANON", Box::new(AiRecord::new(0.0)))
.await
.unwrap();
{
let rec = db.get_record("AI:ANON").await.unwrap();
rec.write().await.common.asg = "ANON_GATED".to_string();
}
let acl = AcfAccessControl::new(db.clone(), cfg);
assert!(
acl.can_read("AI:ANON", "alice", "h"),
"legacy can_read must match METHOD(\"anonymous\") rule"
);
let id_creds = ClientCreds {
user: "alice".to_string(),
host: "h".to_string(),
method: String::new(),
authority: String::new(),
roles: Vec::new(),
};
assert!(
acl.can_read_creds("AI:ANON", &id_creds),
"empty-method ClientCreds must match METHOD(\"anonymous\") rule"
);
let ca_creds = ClientCreds {
user: "alice".to_string(),
host: "h".to_string(),
method: "ca".to_string(),
authority: String::new(),
roles: Vec::new(),
};
assert!(
!acl.can_read_creds("AI:ANON", &ca_creds),
"an explicit 'ca' method must NOT match a METHOD(\"anonymous\")-only rule"
);
}
}