use std::collections::HashMap;
use crate::error::{CaError, CaResult};
/// Access level granted to a client for a PV.
///
/// `rule_rank` in this file orders the variants as
/// `NoAccess < Read < ReadWrite`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AccessLevel {
/// Neither reads nor writes are allowed.
NoAccess,
/// Reads are allowed, writes are not.
Read,
/// Both reads and writes are allowed.
ReadWrite,
}
/// Result of an access decision for a single PV, as built by
/// `AccessGate::check`.
///
/// All fields are private; callers read them through the accessor methods.
#[derive(Debug, Clone)]
pub struct AccessChecked {
// PV name the decision was computed for.
pv_name: String,
// Access level that was granted.
level: AccessLevel,
// Trap flag reported by the rule evaluation (`false` for an open gate or
// when no ACF is attached).
rule_was_trap: bool,
// Private marker field.
// NOTE(review): presumably here so the struct can only be built inside
// this module — confirm intent.
_seal: AccessSeal,
}
/// Private zero-sized marker stored in `AccessChecked::_seal`.
#[derive(Debug, Clone)]
struct AccessSeal;
impl AccessChecked {
    /// Name of the PV this decision applies to.
    pub fn pv_name(&self) -> &str {
        self.pv_name.as_str()
    }

    /// The access level that was granted.
    pub fn level(&self) -> AccessLevel {
        self.level
    }

    /// `true` for every level except [`AccessLevel::NoAccess`].
    pub fn allows_read(&self) -> bool {
        self.level != AccessLevel::NoAccess
    }

    /// `true` only for [`AccessLevel::ReadWrite`].
    pub fn allows_write(&self) -> bool {
        self.level == AccessLevel::ReadWrite
    }

    /// Trap flag recorded alongside the level when the decision was made.
    pub fn rule_was_trap(&self) -> bool {
        self.rule_was_trap
    }
}
/// Entry point for access checks: pairs a gate mode (`inner`) with a source
/// for the ACL version number.
#[derive(Clone)]
pub struct AccessGate {
// Whether checks consult an ACF (`Required`) or always grant (`Open`).
inner: AccessGateInner,
// Where `acl_version()` reads its value from.
acl_version: AclVersionSource,
}
/// Source of the value returned by `AccessGate::acl_version`.
#[derive(Clone)]
enum AclVersionSource {
/// Shared counter; this is the only variant `bump_acl_version` modifies.
Atomic(std::sync::Arc<std::sync::atomic::AtomicU64>),
/// Callback invoked on every `acl_version()` call to compute the value.
Aggregator(std::sync::Arc<dyn Fn() -> u64 + Send + Sync>),
}
/// Async callback used by `AccessGate::check`: it receives the PV name and
/// yields a `(String, u8)` pair that `check` passes on as the ASG name and
/// the record's access-security level.
pub type AsgAslResolver = std::sync::Arc<
dyn Fn(String) -> std::pin::Pin<Box<dyn std::future::Future<Output = (String, u8)> + Send>>
+ Send
+ Sync,
>;
/// Operating mode of an `AccessGate`.
#[derive(Clone)]
enum AccessGateInner {
/// Consult the shared ACF cell; while the cell holds `None`, `check`
/// grants `ReadWrite`.
Required {
// Shared, replaceable access-security configuration.
acf: std::sync::Arc<tokio::sync::RwLock<Option<AccessSecurityConfig>>>,
// Maps a PV name to its (ASG name, ASL) pair.
resolver: AsgAslResolver,
},
/// No access security: `check` always grants `ReadWrite`.
Open,
}
impl AccessGate {
pub fn required(
acf: std::sync::Arc<tokio::sync::RwLock<Option<AccessSecurityConfig>>>,
resolver: AsgAslResolver,
) -> Self {
Self::required_with_version(
acf,
resolver,
std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0)),
)
}
pub fn required_with_version(
acf: std::sync::Arc<tokio::sync::RwLock<Option<AccessSecurityConfig>>>,
resolver: AsgAslResolver,
acl_version: std::sync::Arc<std::sync::atomic::AtomicU64>,
) -> Self {
Self {
inner: AccessGateInner::Required { acf, resolver },
acl_version: AclVersionSource::Atomic(acl_version),
}
}
pub fn open() -> Self {
Self {
inner: AccessGateInner::Open,
acl_version: AclVersionSource::Atomic(std::sync::Arc::new(
std::sync::atomic::AtomicU64::new(0),
)),
}
}
pub fn open_with_aggregator(f: std::sync::Arc<dyn Fn() -> u64 + Send + Sync>) -> Self {
Self {
inner: AccessGateInner::Open,
acl_version: AclVersionSource::Aggregator(f),
}
}
pub fn acl_version(&self) -> u64 {
match &self.acl_version {
AclVersionSource::Atomic(a) => a.load(std::sync::atomic::Ordering::Acquire),
AclVersionSource::Aggregator(f) => f(),
}
}
pub fn bump_acl_version(&self) {
if let AclVersionSource::Atomic(a) = &self.acl_version {
a.fetch_add(1, std::sync::atomic::Ordering::Release);
}
}
pub async fn check(
&self,
pv_name: impl Into<String>,
host: &str,
user: &str,
method: &str,
authority: &str,
) -> AccessChecked {
let pv_name = pv_name.into();
let (level, rule_was_trap) = match &self.inner {
AccessGateInner::Open => (AccessLevel::ReadWrite, false),
AccessGateInner::Required { acf, resolver } => {
let guard = acf.read().await;
match *guard {
None => (AccessLevel::ReadWrite, false),
Some(ref cfg) => {
let (asg, asl) = resolver(pv_name.clone()).await;
cfg.check_access_method_trap(&asg, host, user, asl, method, authority)
}
}
}
};
AccessChecked {
pv_name,
level,
rule_was_trap,
_seal: AccessSeal,
}
}
}
// Tests for `AccessGate::check` in its open and ACF-required modes.
#[cfg(test)]
mod access_checked_tests {
use super::*;
use std::sync::Arc;
// An open gate grants read/write and echoes the PV name back.
#[tokio::test]
async fn open_gate_grants_read_write() {
let gate = AccessGate::open();
let checked = gate.check("any:pv", "h", "u", "anonymous", "").await;
assert_eq!(checked.level(), AccessLevel::ReadWrite);
assert!(checked.allows_read());
assert!(checked.allows_write());
assert_eq!(checked.pv_name(), "any:pv");
}
// A required gate whose ACF cell is still `None` grants read/write.
#[tokio::test]
async fn required_gate_with_no_acf_attached_is_permissive() {
let cell = Arc::new(tokio::sync::RwLock::new(None));
let resolver: AsgAslResolver =
Arc::new(|_pv| Box::pin(async { ("DEFAULT".to_string(), 0u8) }));
let gate = AccessGate::required(cell, resolver);
let checked = gate.check("any:pv", "h", "u", "anonymous", "").await;
assert_eq!(checked.level(), AccessLevel::ReadWrite);
}
// With an ACF attached, a UAG member gets READ and anyone else gets nothing.
#[tokio::test]
async fn required_gate_with_acf_denies_unprivileged_peer() {
let cfg = parse_acf(
r#"
UAG(ops) { alice }
ASG(DEFAULT) {
RULE(0, READ) { UAG(ops) }
}
"#,
)
.unwrap();
let cell = Arc::new(tokio::sync::RwLock::new(Some(cfg)));
let resolver: AsgAslResolver =
Arc::new(|_pv| Box::pin(async { ("DEFAULT".to_string(), 0u8) }));
let gate = AccessGate::required(cell, resolver);
let allowed = gate.check("x", "h", "alice", "anonymous", "").await;
assert!(allowed.allows_read());
assert!(!allowed.allows_write());
let denied = gate.check("x", "h", "intruder", "anonymous", "").await;
assert_eq!(denied.level(), AccessLevel::NoAccess);
assert!(!denied.allows_read());
}
}
/// Access keyword of a `RULE(level, access)` entry.
///
/// `rule_access` maps these to `AccessLevel`: `None` → `NoAccess`,
/// `Read` → `Read`, `Write` → `ReadWrite`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum RuleAccess {
/// `NONE` — also the default value.
#[default]
None,
/// `READ`.
Read,
/// `WRITE`.
Write,
}
/// One parsed `RULE(...)` entry of an access security group.
#[derive(Debug, Clone, Default)]
pub struct AccessRule {
/// Rule level; the rule is skipped when the record's ASL is greater.
pub level: u8,
/// Access keyword the rule grants when it matches.
pub access: RuleAccess,
/// Names of UAG groups; an empty list matches every user.
pub uag: Vec<String>,
/// Names of HAG groups; an empty list matches every host.
pub hag: Vec<String>,
/// Accepted methods (compared ASCII-case-insensitively); empty matches any.
pub method: Vec<String>,
/// Accepted authorities (compared ASCII-case-insensitively); empty matches any.
pub authority: Vec<String>,
/// Set when the rule carries the `TRAPWRITE` log option.
pub trap: bool,
/// Text of the rule's `CALC(...)` expression, if one was given.
pub calc: Option<String>,
/// When set, the rule never grants access (see `rule_access`).
pub ignore: bool,
}
/// Access level a rule grants when it matches; an ignored rule grants nothing.
fn rule_access(rule: &AccessRule) -> AccessLevel {
    match (rule.ignore, rule.access) {
        (true, _) | (false, RuleAccess::None) => AccessLevel::NoAccess,
        (false, RuleAccess::Read) => AccessLevel::Read,
        (false, RuleAccess::Write) => AccessLevel::ReadWrite,
    }
}
fn rule_rank(level: AccessLevel) -> u8 {
match level {
AccessLevel::NoAccess => 0,
AccessLevel::Read => 1,
AccessLevel::ReadWrite => 2,
}
}
/// Parsed body of one `ASG(name) { ... }` block.
#[derive(Debug, Clone, Default)]
pub struct AccessSecurityGroup {
/// Rules in the order they were parsed; evaluation walks them in order.
pub rules: Vec<AccessRule>,
/// `INPx(link)` entries collected from the block.
pub inp: Vec<AsgInp>,
}
/// One `INPx(link)` entry of an ASG block.
#[derive(Debug, Clone)]
pub struct AsgInp {
/// Zero-based selector: `parse_inp_index` maps `A..=U` to `0..=20`.
pub index: u8,
/// Text found between the parentheses.
pub link: String,
}
/// In-memory form of an access security configuration file (ACF).
#[derive(Debug, Clone)]
pub struct AccessSecurityConfig {
/// User access groups: group name → member user names.
pub uag: HashMap<String, Vec<String>>,
/// Host access groups: group name → member host entries.
pub hag: HashMap<String, Vec<String>>,
/// Access security groups by name.
pub asg: HashMap<String, AccessSecurityGroup>,
/// NOTE(review): `parse_acf` sets this to `Read`; no code visible in this
/// file reads it — confirm where it is used.
pub unknown_access: AccessLevel,
}
impl AccessSecurityConfig {
    /// Access level for `user`@`host` in `asg_name`, with record ASL 0 and no
    /// method/authority information.
    pub fn check_access(&self, asg_name: &str, host: &str, user: &str) -> AccessLevel {
        self.check_access_asl(asg_name, host, user, 0)
    }

    /// Like [`Self::check_access_method_trap`], but returns only the level.
    pub fn check_access_method(
        &self,
        asg_name: &str,
        host: &str,
        user: &str,
        record_asl: u8,
        method: &str,
        authority: &str,
    ) -> AccessLevel {
        self.check_access_method_trap(asg_name, host, user, record_asl, method, authority)
            .0
    }

    /// Evaluates the rules of `asg_name` for the given peer.
    ///
    /// Returns the granted level together with the `trap` flag of the rule
    /// that granted it. An unknown `asg_name` falls back to the `DEFAULT`
    /// group; if that is missing too, the result is `(NoAccess, false)`.
    ///
    /// Rules are walked in order. A rule is applied only if it would raise
    /// the level granted so far, its `level` is at least `record_asl`, and
    /// its UAG / HAG / METHOD / AUTHORITY constraints all match (an empty
    /// constraint list matches anything). Host, method and authority are
    /// compared ASCII-case-insensitively; user names are compared exactly.
    pub fn check_access_method_trap(
        &self,
        asg_name: &str,
        host: &str,
        user: &str,
        record_asl: u8,
        method: &str,
        authority: &str,
    ) -> (AccessLevel, bool) {
        let Some(asg) = self
            .asg
            .get(asg_name)
            .or_else(|| self.asg.get("DEFAULT"))
        else {
            return (AccessLevel::NoAccess, false);
        };

        let mut access = AccessLevel::NoAccess;
        let mut trap = false;
        for rule in &asg.rules {
            // Nothing can raise the level beyond read/write.
            if access == AccessLevel::ReadWrite {
                break;
            }
            let granted = rule_access(rule);
            // Ignored rules, and rules that would not raise the level, are skipped.
            if rule.ignore || rule_rank(granted) <= rule_rank(access) {
                continue;
            }
            if record_asl > rule.level {
                continue;
            }

            let user_match = rule.uag.is_empty()
                || rule.uag.iter().any(|group| {
                    self.uag
                        .get(group)
                        .is_some_and(|members| members.iter().any(|m| m == user))
                });
            if !user_match {
                continue;
            }

            // `eq_ignore_ascii_case` already folds both sides, so no
            // lowercased copy of `host` is needed per rule.
            let host_match = rule.hag.is_empty()
                || rule.hag.iter().any(|group| {
                    self.hag
                        .get(group)
                        .is_some_and(|members| members.iter().any(|m| m.eq_ignore_ascii_case(host)))
                });
            if !host_match {
                continue;
            }

            let method_match = rule.method.is_empty()
                || rule.method.iter().any(|m| m.eq_ignore_ascii_case(method));
            if !method_match {
                continue;
            }

            let authority_match = rule.authority.is_empty()
                || rule
                    .authority
                    .iter()
                    .any(|a| a.eq_ignore_ascii_case(authority));
            if !authority_match {
                continue;
            }

            access = granted;
            trap = rule.trap;
        }
        (access, trap)
    }

    /// Access level for the given record ASL, with empty method and authority.
    pub fn check_access_asl(
        &self,
        asg_name: &str,
        host: &str,
        user: &str,
        record_asl: u8,
    ) -> AccessLevel {
        self.check_access_method(asg_name, host, user, record_asl, "", "")
    }
}
/// Phase tag carried in a `TrapWriteMessage`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrapWriteOp {
/// NOTE(review): presumably emitted before the write is applied — confirm
/// against the code that builds the message.
BeforeWrite,
/// NOTE(review): presumably emitted after the write completed — confirm.
AfterWrite,
}
/// Borrowed description of a write, handed to every registered trap-write
/// listener by `dispatch_trap_write`.
///
/// NOTE(review): the field meanings below are inferred from their names; the
/// code that fills the message is not in this file section — confirm there.
#[derive(Debug, Clone, Copy)]
pub struct TrapWriteMessage<'a> {
/// Phase of the write this message reports.
pub op: TrapWriteOp,
/// PV name (also logged when a listener panics).
pub pv_name: &'a str,
/// Presumably the writing user's name.
pub user: &'a str,
/// Presumably the writing client's host name.
pub host: &'a str,
/// Presumably the peer address of the connection.
pub peer: &'a str,
/// Presumably the written value rendered as text.
pub value_str: &'a str,
/// Presumably the DBR type code of the written value.
pub dbr_type: u16,
/// Presumably the element count of the written value.
pub no_elements: u32,
/// Event identifier (also logged when a listener panics).
pub event_id: u64,
/// Optional status text.
pub status: Option<&'a str>,
/// Presumably the trap flag of the rule that granted the write.
pub rule_was_trap: bool,
}
/// Process-wide counter backing [`next_trap_write_event_id`]; starts at 1.
static TRAP_WRITE_EVENT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);

/// Returns the current event id and advances the counter by one.
pub fn next_trap_write_event_id() -> u64 {
    use std::sync::atomic::Ordering::Relaxed;
    TRAP_WRITE_EVENT_ID.fetch_add(1, Relaxed)
}
/// Shared, thread-safe callback invoked with each dispatched trap-write message.
pub type TrapWriteListener = std::sync::Arc<dyn Fn(&TrapWriteMessage<'_>) + Send + Sync>;
/// Registration token returned by `register_trap_write_listener`.
///
/// Dropping the handle removes the listener with the same id from the
/// registry (see the `Drop` impl).
pub struct TrapWriteListenerHandle {
// Id assigned at registration time.
id: u64,
}
impl Drop for TrapWriteListenerHandle {
    /// Removes this handle's listener from the registry, if one exists.
    ///
    /// A poisoned registry lock is recovered rather than treated as fatal:
    /// panicking inside `drop` would abort the process when the handle is
    /// dropped during unwinding, and would leave the listener registered.
    /// The guarded `Vec` is only ever pushed to or filtered, so its contents
    /// remain usable after a poisoning panic.
    fn drop(&mut self) {
        let Some(reg) = TRAP_WRITE_REGISTRY.get() else {
            return;
        };
        let mut guard = reg
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        guard.retain(|(id, _)| *id != self.id);
    }
}
/// Lazily created list of `(id, listener)` pairs; unset until first needed.
static TRAP_WRITE_REGISTRY: std::sync::OnceLock<std::sync::RwLock<Vec<(u64, TrapWriteListener)>>> =
    std::sync::OnceLock::new();

/// Source of listener ids; starts at 1.
static TRAP_WRITE_NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);

/// Returns the listener registry, creating an empty one on first use.
fn trap_write_registry() -> &'static std::sync::RwLock<Vec<(u64, TrapWriteListener)>> {
    TRAP_WRITE_REGISTRY.get_or_init(Default::default)
}
/// Adds `listener` to the registry under a fresh id and returns the handle
/// that owns the registration.
///
/// # Panics
///
/// Panics if the registry lock is poisoned.
pub fn register_trap_write_listener(listener: TrapWriteListener) -> TrapWriteListenerHandle {
    use std::sync::atomic::Ordering::Relaxed;

    let id = TRAP_WRITE_NEXT_ID.fetch_add(1, Relaxed);
    // The handle is built only after the push succeeds, so a panic here never
    // runs the handle's `Drop`.
    trap_write_registry()
        .write()
        .expect("trap-write registry poisoned")
        .push((id, listener));
    TrapWriteListenerHandle { id }
}
/// Reports whether at least one trap-write listener is registered.
///
/// Returns `false` without creating the registry if it was never initialised.
///
/// # Panics
///
/// Panics if the registry lock is poisoned.
pub fn has_trap_write_listeners() -> bool {
    TRAP_WRITE_REGISTRY.get().is_some_and(|reg| {
        !reg.read()
            .expect("trap-write registry poisoned")
            .is_empty()
    })
}
/// Delivers `msg` to every registered trap-write listener.
///
/// The listener list is copied under the read lock and the lock is released
/// before any listener runs. A panicking listener is caught and logged; the
/// remaining listeners are still called.
///
/// # Panics
///
/// Panics if the registry lock is poisoned.
pub fn dispatch_trap_write(msg: &TrapWriteMessage<'_>) {
    /// Renders a caught panic payload as text for the log record.
    fn describe_panic(payload: &(dyn std::any::Any + Send)) -> String {
        match (
            payload.downcast_ref::<&'static str>(),
            payload.downcast_ref::<String>(),
        ) {
            (Some(text), _) => (*text).to_string(),
            (None, Some(text)) => text.clone(),
            (None, None) => "(non-string panic payload)".to_string(),
        }
    }

    let Some(reg) = TRAP_WRITE_REGISTRY.get() else {
        return;
    };
    // The temporary read guard is dropped at the end of this statement.
    let listeners: Vec<TrapWriteListener> = reg
        .read()
        .expect("trap-write registry poisoned")
        .iter()
        .map(|(_, listener)| listener.clone())
        .collect();

    for listener in listeners {
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            listener(msg);
        }));
        let Err(payload) = outcome else {
            continue;
        };
        let descr = describe_panic(&*payload);
        tracing::error!(
            target: "epics_base_rs::server::access_security",
            pv = msg.pv_name,
            event_id = msg.event_id,
            op = ?msg.op,
            panic = %descr,
            "TRAPWRITE listener panicked — isolating; remaining listeners will still run. \
             C asTrapWriteWithData has no unwind concept; this is a Rust-only safety net \
             to keep the per-circuit task alive."
        );
    }
}
/// Lazily created sender of the ASG-change broadcast channel.
static ASG_CHANGE_BROADCAST: std::sync::OnceLock<tokio::sync::broadcast::Sender<()>> =
    std::sync::OnceLock::new();

/// Returns the broadcast sender, creating a capacity-16 channel on first use.
/// The receiver created alongside it is dropped immediately.
fn asg_change_broadcast() -> &'static tokio::sync::broadcast::Sender<()> {
    ASG_CHANGE_BROADCAST.get_or_init(|| tokio::sync::broadcast::channel(16).0)
}
/// Sends a unit notification on the ASG-change broadcast channel.
///
/// The result of `send` is deliberately discarded.
pub fn notify_asg_field_changed() {
let _ = asg_change_broadcast().send(());
}
/// Returns a new receiver subscribed to the ASG-change broadcast channel.
pub fn subscribe_asg_changes() -> tokio::sync::broadcast::Receiver<()> {
asg_change_broadcast().subscribe()
}
/// Parses the text of an access security configuration file.
///
/// The returned configuration always contains a `DEFAULT` ASG (empty unless
/// the file defines one) and has `unknown_access` set to `Read`.
///
/// Recognised top-level blocks are `UAG(name) { ... }`, `HAG(name) { ... }`
/// and `ASG(name) { ... }`. HAG members are lowercased and then expanded with
/// their resolved IP addresses. Any other keyword followed by `(` is skipped
/// with a warning.
///
/// # Errors
///
/// Returns `CaError::Protocol` when a block is malformed, or when anything
/// other than a keyword appears at the top level. Only end of input ends the
/// parse; an unexpected character is never treated as end of file, so a stray
/// character cannot silently drop the rest of the configuration.
pub fn parse_acf(content: &str) -> CaResult<AccessSecurityConfig> {
    let mut config = AccessSecurityConfig {
        uag: HashMap::new(),
        hag: HashMap::new(),
        asg: HashMap::new(),
        unknown_access: AccessLevel::Read,
    };
    config
        .asg
        .insert("DEFAULT".to_string(), AccessSecurityGroup::default());

    let mut chars = content.chars().peekable();
    let mut buf = String::new();
    while chars.peek().is_some() {
        skip_ws_comments(&mut chars);
        buf.clear();
        read_word(&mut chars, &mut buf);
        match buf.as_str() {
            "UAG" => {
                let name = read_paren_name(&mut chars)?;
                let members = read_brace_list(&mut chars)?;
                config.uag.insert(name, members);
            }
            "HAG" => {
                let name = read_paren_name(&mut chars)?;
                let members = read_brace_list(&mut chars)?;
                let lowered: Vec<String> =
                    members.iter().map(|m| m.to_ascii_lowercase()).collect();
                config.hag.insert(name, expand_hag_members(&lowered));
            }
            "ASG" => {
                let name = read_paren_name(&mut chars)?;
                let asg = parse_asg_body(&mut chars)?;
                config.asg.insert(name, asg);
            }
            // No keyword could be read: that is fine only at end of input.
            "" => match chars.peek() {
                None => break,
                Some(&c) => {
                    return Err(CaError::Protocol(format!(
                        "ACF: unexpected '{c}' where a top-level block keyword is expected"
                    )));
                }
            },
            other => skip_unknown_top_level_block(other, &mut chars)?,
        }
    }
    Ok(config)
}
/// Skips an unsupported top-level block `KEYWORD( ... )` with an optional
/// `{ ... }` body, logging a warning once it has been skipped.
///
/// # Errors
///
/// Returns `CaError::Protocol` if the keyword is not followed by `(`, or if
/// the parentheses or braces are not balanced before end of input.
fn skip_unknown_top_level_block(
    keyword: &str,
    chars: &mut std::iter::Peekable<std::str::Chars>,
) -> CaResult<()> {
    /// Consumes characters until the `open`/`close` nesting returns to zero.
    /// Returns `false` if input ends first.
    fn consume_balanced(
        chars: &mut std::iter::Peekable<std::str::Chars>,
        open: char,
        close: char,
    ) -> bool {
        let mut depth = 0;
        for ch in chars.by_ref() {
            if ch == open {
                depth += 1;
            } else if ch == close {
                depth -= 1;
                if depth == 0 {
                    return true;
                }
            }
        }
        false
    }

    skip_ws_comments(chars);
    if chars.peek() != Some(&'(') {
        return Err(CaError::Protocol(format!(
            "ACF: unexpected token '{keyword}' — expected a top-level \
             UAG/HAG/ASG block or an unknown keyword followed by '('"
        )));
    }
    if !consume_balanced(chars, '(', ')') {
        return Err(CaError::Protocol(format!(
            "ACF: unbalanced '(' in unsupported top-level block '{keyword}'"
        )));
    }

    skip_ws_comments(chars);
    // The brace body is optional.
    if chars.peek() == Some(&'{') && !consume_balanced(chars, '{', '}') {
        return Err(CaError::Protocol(format!(
            "ACF: unbalanced '{{' in unsupported top-level block '{keyword}'"
        )));
    }

    tracing::warn!(
        target: "epics_base_rs::access_security",
        keyword = %keyword,
        "ACF: ignoring unsupported top-level block"
    );
    Ok(())
}
/// Returns `members` with the resolved IP addresses of each host name
/// appended after it.
///
/// Every input entry is kept as-is. Entries that already parse as an IP
/// address are not looked up. A failed lookup is logged at debug level and
/// adds nothing. An address already present in the output is not added again.
fn expand_hag_members(members: &[String]) -> Vec<String> {
    use std::net::{IpAddr, ToSocketAddrs};

    let mut expanded: Vec<String> = Vec::with_capacity(members.len());
    for member in members {
        expanded.push(member.clone());
        if member.parse::<IpAddr>().is_ok() {
            continue;
        }
        // Port 0 is a placeholder; only the resolved addresses are used.
        let resolved = match format!("{member}:0").to_socket_addrs() {
            Ok(addrs) => addrs,
            Err(e) => {
                tracing::debug!(
                    target: "epics_base_rs::access_security",
                    host = %member,
                    error = %e,
                    "ACF HAG: DNS lookup failed; keeping literal entry (libcom 932e9f3 soft fallback)"
                );
                continue;
            }
        };
        for ip in resolved.map(|addr| addr.ip().to_string()) {
            if !expanded.contains(&ip) {
                expanded.push(ip);
            }
        }
    }
    expanded
}
/// Advances past whitespace and `#` comments (each running through the end of
/// its line), stopping at the first other character or at end of input.
fn skip_ws_comments(chars: &mut std::iter::Peekable<std::str::Chars>) {
    loop {
        match chars.peek().copied() {
            Some(ch) if ch.is_whitespace() => {
                chars.next();
            }
            Some('#') => {
                // Consume the comment, including its terminating newline.
                for ch in chars.by_ref() {
                    if ch == '\n' {
                        break;
                    }
                }
            }
            _ => return,
        }
    }
}
/// Appends the leading run of alphanumeric or `_` characters to `buf`,
/// leaving the first other character unconsumed.
fn read_word(chars: &mut std::iter::Peekable<std::str::Chars>, buf: &mut String) {
    while let Some(ch) = chars.next_if(|c| c.is_alphanumeric() || *c == '_') {
        buf.push(ch);
    }
}
/// Reads a parenthesised name: `(name)` or `("quoted name")`.
///
/// Whitespace and comments are allowed before `(`, after it, and before `)`.
/// A quoted name is taken verbatim up to the closing quote. An unquoted name
/// is every character up to `)` and may not contain whitespace.
///
/// # Errors
///
/// Returns `CaError::Protocol` when `(` is missing, a quoted name is not
/// terminated or not followed by `)`, whitespace appears inside an unquoted
/// name, or input ends before `)`.
fn read_paren_name(chars: &mut std::iter::Peekable<std::str::Chars>) -> CaResult<String> {
    skip_ws_comments(chars);
    if chars.next() != Some('(') {
        return Err(CaError::Protocol("ACF: expected '('".into()));
    }
    skip_ws_comments(chars);
    let mut name = String::new();

    if chars.next_if_eq(&'"').is_some() {
        let mut terminated = false;
        for ch in chars.by_ref() {
            if ch == '"' {
                terminated = true;
                break;
            }
            name.push(ch);
        }
        if !terminated {
            return Err(CaError::Protocol("ACF: unterminated quoted name".into()));
        }
        skip_ws_comments(chars);
        return match chars.next() {
            Some(')') => Ok(name),
            _ => Err(CaError::Protocol(
                "ACF: expected ')' after quoted name".into(),
            )),
        };
    }

    while let Some(ch) = chars.next() {
        if ch == ')' {
            return Ok(name);
        }
        if !ch.is_whitespace() {
            name.push(ch);
            continue;
        }
        // Whitespace ends the name; only `)` may follow it.
        skip_ws_comments(chars);
        return match chars.next() {
            Some(')') => Ok(name),
            Some(_) => Err(CaError::Protocol(
                "ACF: whitespace inside parenthesised name".into(),
            )),
            None => Err(CaError::Protocol(
                "ACF: unterminated '(' — missing ')'".into(),
            )),
        };
    }
    Err(CaError::Protocol(
        "ACF: unterminated '(' — missing ')'".into(),
    ))
}
/// Reads a brace-delimited member list such as `{ alice, bob, "ANL CA" }`.
///
/// Bare members consist of alphanumerics and `_ . - + : [ ] < > ;`. A bare
/// member ends at a comma, at the closing brace, at the start of a quoted
/// member, and at whitespace or a `#` comment — so `{ alice bob }` yields two
/// members instead of gluing them into `alicebob`. Quoted members may contain
/// any character except a raw newline; a backslash takes the next character
/// literally. Empty members are dropped, and any other character is skipped.
///
/// # Errors
///
/// Returns `CaError::Protocol` when `{` is missing, a quoted member is not
/// terminated on its line, or input ends before `}`.
fn read_brace_list(chars: &mut std::iter::Peekable<std::str::Chars>) -> CaResult<Vec<String>> {
    /// Moves a non-empty pending token into `items`, leaving `token` empty.
    fn flush(token: &mut String, items: &mut Vec<String>) {
        if !token.is_empty() {
            items.push(std::mem::take(token));
        }
    }

    skip_ws_comments(chars);
    if chars.next() != Some('{') {
        return Err(CaError::Protocol("ACF: expected '{'".into()));
    }
    let mut items = Vec::new();
    let mut current = String::new();
    loop {
        // Whitespace or a comment separates members just like a comma does.
        if matches!(chars.peek(), Some(&c) if c.is_whitespace() || c == '#') {
            flush(&mut current, &mut items);
            skip_ws_comments(chars);
        }
        match chars.peek() {
            Some(&'}') => {
                chars.next();
                break;
            }
            Some(&',') => {
                chars.next();
                flush(&mut current, &mut items);
            }
            Some(&'"') => {
                chars.next();
                flush(&mut current, &mut items);
                let mut quoted = String::new();
                loop {
                    match chars.next() {
                        Some('"') => break,
                        Some('\\') => {
                            if let Some(esc) = chars.next() {
                                quoted.push(esc);
                            }
                        }
                        Some('\n') | None => {
                            return Err(CaError::Protocol(
                                "ACF: unterminated quoted string".into(),
                            ));
                        }
                        Some(c) => quoted.push(c),
                    }
                }
                if !quoted.is_empty() {
                    items.push(quoted);
                }
            }
            Some(&c)
                if c.is_alphanumeric()
                    || matches!(c, '_' | '.' | '-' | '+' | ':' | '[' | ']' | '<' | '>' | ';') =>
            {
                current.push(c);
                chars.next();
            }
            Some(_) => {
                chars.next();
            }
            None => return Err(CaError::Protocol("ACF: unterminated '{'".into())),
        }
    }
    flush(&mut current, &mut items);
    Ok(items)
}
/// Parses the `{ ... }` body of an `ASG(name)` block.
///
/// `RULE` entries are parsed and appended in order. `INPx(link)` entries are
/// recorded with their selector index. Other keywords are passed over, and a
/// character that starts no keyword is skipped.
///
/// # Errors
///
/// Returns `CaError::Protocol` when `{` is missing, an `INP` selector is
/// invalid, a nested rule or link fails to parse, or input ends before `}`.
fn parse_asg_body(
    chars: &mut std::iter::Peekable<std::str::Chars>,
) -> CaResult<AccessSecurityGroup> {
    skip_ws_comments(chars);
    if chars.next() != Some('{') {
        return Err(CaError::Protocol("ACF: expected '{' after ASG name".into()));
    }
    let mut group = AccessSecurityGroup::default();
    loop {
        skip_ws_comments(chars);
        match chars.peek() {
            None => return Err(CaError::Protocol("ACF: unterminated ASG".into())),
            Some(&'}') => {
                chars.next();
                return Ok(group);
            }
            Some(_) => {}
        }

        let mut keyword = String::new();
        read_word(chars, &mut keyword);
        if keyword.is_empty() {
            // Not a word character: step over it so the loop makes progress.
            chars.next();
            continue;
        }
        if keyword == "RULE" {
            group.rules.push(parse_rule(chars)?);
            continue;
        }
        let Some(selector) = keyword.strip_prefix("INP") else {
            // Any other keyword is passed over.
            continue;
        };
        let Some(index) = parse_inp_index(selector) else {
            return Err(CaError::Protocol(format!(
                "ACF: invalid INP link selector 'INP{selector}' (expected INPA..INPU)"
            )));
        };
        let link = read_paren_name(chars)?;
        group.inp.push(AsgInp { index, link });
    }
}
/// Maps the single-letter suffix of an `INPx` keyword to its index:
/// `A` → 0 through `U` → 20, accepting either ASCII case.
///
/// Returns `None` for an empty suffix, a suffix longer than one character,
/// or a letter outside `A..=U`.
fn parse_inp_index(suffix: &str) -> Option<u8> {
    let mut letters = suffix.chars();
    match (letters.next(), letters.next()) {
        (Some(letter), None) => {
            let upper = letter.to_ascii_uppercase();
            ('A'..='U')
                .contains(&upper)
                .then(|| (upper as u8) - b'A')
        }
        _ => None,
    }
}
/// Parses one rule, starting right after the `RULE` keyword:
/// `(level, access[, TRAPWRITE|NOTRAPWRITE]) [{ UAG(..) HAG(..) METHOD(..) AUTHORITY(..) CALC(..) }]`.
///
/// The returned rule has `ignore` set when the access keyword is not
/// WRITE/READ/NONE, when the body contains an unsupported keyword, or when
/// the rule carries a `CALC` expression.
///
/// # Errors
///
/// Returns `CaError::Protocol` when `(` is missing, the level is not an
/// integer in `0..=255`, the log option is not TRAPWRITE/NOTRAPWRITE, a
/// nested name or list fails to parse, or a `CALC` expression does not compile.
fn parse_rule(chars: &mut std::iter::Peekable<std::str::Chars>) -> CaResult<AccessRule> {
skip_ws_comments(chars);
if chars.next() != Some('(') {
return Err(CaError::Protocol("ACF: expected '(' after RULE".into()));
}
skip_ws_comments(chars);
// Level: an optional sign followed by ASCII digits. A sign is accepted here
// so that a negative level gets its own error message below.
let mut level_str = String::new();
if matches!(chars.peek(), Some('+') | Some('-')) {
level_str.push(chars.next().unwrap());
}
while let Some(&c) = chars.peek() {
if c.is_ascii_digit() {
level_str.push(c);
chars.next();
} else {
break;
}
}
let level_num: i64 = level_str.parse().map_err(|_| {
CaError::Protocol(format!(
"ACF: RULE level must be an integer, got '{level_str}'"
))
})?;
if level_num < 0 {
return Err(CaError::Protocol(format!(
"ACF: RULE LEVEL must be positive: {level_num}"
)));
}
let level: u8 = u8::try_from(level_num)
.map_err(|_| CaError::Protocol(format!("ACF: RULE level out of range: {level_num}")))?;
skip_ws_comments(chars);
// The comma between level and access keyword is consumed if present.
if chars.peek() == Some(&',') {
chars.next();
}
skip_ws_comments(chars);
// Access keyword, matched ASCII-case-insensitively. An unrecognised keyword
// does not fail the parse: the rule is kept but marked `ignore`.
let mut access_str = String::new();
read_word(chars, &mut access_str);
let (access, mut ignore) = if access_str.eq_ignore_ascii_case("WRITE") {
(RuleAccess::Write, false)
} else if access_str.eq_ignore_ascii_case("READ") {
(RuleAccess::Read, false)
} else if access_str.eq_ignore_ascii_case("NONE") {
(RuleAccess::None, false)
} else {
tracing::warn!(
target: "epics_base_rs::access_security",
keyword = %access_str,
"ACF: ignoring RULE with unsupported access keyword"
);
(RuleAccess::None, true)
};
// Optional third argument: the log option.
let mut trap = false;
skip_ws_comments(chars);
if chars.peek() == Some(&',') {
chars.next();
skip_ws_comments(chars);
let mut log_opt = String::new();
read_word(chars, &mut log_opt);
if log_opt.eq_ignore_ascii_case("TRAPWRITE") {
trap = true;
} else if !log_opt.eq_ignore_ascii_case("NOTRAPWRITE") {
return Err(CaError::Protocol(format!(
"ACF: RULE log option must be TRAPWRITE or NOTRAPWRITE, got '{log_opt}'"
)));
}
}
skip_ws_comments(chars);
// The closing ')' of the rule head is consumed if present; its absence is
// not reported here.
if chars.peek() == Some(&')') {
chars.next();
}
let mut uag = Vec::new();
let mut hag = Vec::new();
let mut method = Vec::new();
let mut authority = Vec::new();
let mut calc: Option<String> = None;
skip_ws_comments(chars);
// Optional `{ ... }` body holding the rule's constraints.
if chars.peek() == Some(&'{') {
chars.next();
loop {
skip_ws_comments(chars);
match chars.peek() {
Some(&'}') => {
chars.next();
break;
}
Some(_) => {
let mut kw = String::new();
read_word(chars, &mut kw);
if kw == "UAG" {
let name = read_paren_name(chars)?;
uag.push(name);
} else if kw == "HAG" {
let name = read_paren_name(chars)?;
hag.push(name);
} else if kw == "METHOD" {
method.extend(read_paren_string_list(chars)?);
} else if kw == "AUTHORITY" {
authority.extend(read_paren_string_list(chars)?);
} else if kw == "CALC" {
let expr = read_paren_name_raw(chars)?;
calc = Some(expr);
} else if kw.is_empty() {
// Not a word character: skip it so the loop makes progress.
chars.next();
} else {
// Unknown keyword: disable the rule and skip the keyword's
// parenthesised argument if it has one.
tracing::warn!(
target: "epics_base_rs::access_security",
keyword = %kw,
"ACF: ignoring RULE with unsupported keyword — rule disabled"
);
ignore = true;
skip_ws_comments(chars);
if chars.peek() == Some(&'(') {
let _ = read_paren_name(chars)?;
}
}
}
// End of input inside the body ends the rule without an error here.
None => break,
}
}
}
// A CALC expression must compile, but a rule carrying one is always disabled.
if let Some(ref expr) = calc {
match crate::calc::compile(expr) {
Ok(_) => {
tracing::warn!(
target: "epics_base_rs::access_security",
calc = %expr,
"ACF: CALC-gated RULE disabled (no INP* link \
resolution in this crate; rule fails CLOSED — \
access decisions will diverge from EPICS Base / \
pvxs which would evaluate the expression \
dynamically). See BR-R32."
);
ignore = true;
}
Err(e) => {
return Err(CaError::Protocol(format!(
"ACF: bad CALC expression '{expr}': {e}"
)));
}
}
}
Ok(AccessRule {
level,
access,
uag,
hag,
method,
authority,
trap,
calc,
ignore,
})
}
/// Reads the argument of `CALC(...)` and returns it trimmed.
///
/// The expression is either quoted — `("A=1")`, taken verbatim up to the
/// closing quote — or unquoted, in which case it is every character up to the
/// first `)`.
///
/// # Errors
///
/// Returns `CaError::Protocol` when `(` is missing, a quoted expression is
/// not terminated or not followed by `)`, or an unquoted expression reaches
/// end of input without a closing `)`.
fn read_paren_name_raw(chars: &mut std::iter::Peekable<std::str::Chars>) -> CaResult<String> {
    skip_ws_comments(chars);
    if chars.next() != Some('(') {
        return Err(CaError::Protocol("ACF: expected '(' after CALC".into()));
    }
    skip_ws_comments(chars);
    let mut body = String::new();
    if chars.next_if_eq(&'"').is_some() {
        let mut closed = false;
        for c in chars.by_ref() {
            if c == '"' {
                closed = true;
                break;
            }
            body.push(c);
        }
        if !closed {
            return Err(CaError::Protocol(
                "ACF: unterminated quoted CALC expression".into(),
            ));
        }
        skip_ws_comments(chars);
        if chars.next() != Some(')') {
            return Err(CaError::Protocol(
                "ACF: expected ')' after CALC expression".into(),
            ));
        }
    } else {
        let mut closed = false;
        for c in chars.by_ref() {
            if c == ')' {
                closed = true;
                break;
            }
            body.push(c);
        }
        // Hitting end of input without ')' is a truncated file, not an expression.
        if !closed {
            return Err(CaError::Protocol(
                "ACF: unterminated '(' in CALC — missing ')'".into(),
            ));
        }
    }
    Ok(body.trim().to_string())
}
/// Reads the argument list of `METHOD(...)` or `AUTHORITY(...)`.
///
/// Items are separated by commas. Double quotes are removed from the output
/// and switch "quoted" mode on and off; inside quotes, `,` and `)` are
/// ordinary characters. Each item is trimmed, and empty items are dropped.
///
/// # Errors
///
/// Returns `CaError::Protocol` when `(` is missing or input ends before the
/// closing `)`.
fn read_paren_string_list(
    chars: &mut std::iter::Peekable<std::str::Chars>,
) -> CaResult<Vec<String>> {
    /// Stores the trimmed pending item if it is non-empty, then resets it.
    fn flush(pending: &mut String, items: &mut Vec<String>) {
        let item = pending.trim();
        if !item.is_empty() {
            items.push(item.to_string());
        }
        pending.clear();
    }

    skip_ws_comments(chars);
    if chars.next() != Some('(') {
        return Err(CaError::Protocol(
            "ACF: expected '(' after METHOD/AUTHORITY".into(),
        ));
    }

    let mut items = Vec::new();
    let mut pending = String::new();
    let mut quoted = false;
    loop {
        let Some(ch) = chars.next() else {
            return Err(CaError::Protocol(
                "ACF: unterminated METHOD/AUTHORITY list".into(),
            ));
        };
        match ch {
            '"' => quoted = !quoted,
            ')' if !quoted => break,
            ',' if !quoted => flush(&mut pending, &mut items),
            other => pending.push(other),
        }
    }
    flush(&mut pending, &mut items);
    Ok(items)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_acf_basic() {
let acf = r#"
UAG(admins) { user1, user2 }
HAG(operators) { host1, host2 }
ASG(DEFAULT) {
RULE(1, WRITE) { UAG(admins) HAG(operators) }
RULE(1, READ)
}
"#;
let config = parse_acf(acf).unwrap();
assert_eq!(config.uag.get("admins").unwrap(), &["user1", "user2"]);
assert_eq!(config.hag.get("operators").unwrap(), &["host1", "host2"]);
assert!(config.asg.contains_key("DEFAULT"));
assert_eq!(config.asg["DEFAULT"].rules.len(), 2);
}
#[test]
fn test_parse_acf_hag_uag() {
let acf = r#"
UAG(ops) { alice, bob }
HAG(lab) { lab-pc1.invalid }
ASG(SECURE) {
RULE(1, WRITE) { UAG(ops) HAG(lab) }
RULE(1, READ)
}
"#;
let config = parse_acf(acf).unwrap();
assert_eq!(config.uag["ops"], vec!["alice", "bob"]);
assert_eq!(config.hag["lab"], vec!["lab-pc1.invalid"]);
}
#[test]
fn hag_dns_resolution_appends_ip_for_match() {
let acf = r#"
HAG(local) { localhost }
ASG(DEFAULT) {
RULE(1, WRITE) { HAG(local) }
}
"#;
let config = parse_acf(acf).unwrap();
let entries = &config.hag["local"];
assert!(
entries.contains(&"localhost".to_string()),
"literal hostname always preserved"
);
assert!(
entries.iter().any(|s| s == "127.0.0.1"),
"resolved IPv4 appended for IP-presenting peers"
);
}
#[test]
fn hag_unresolvable_name_does_not_abort_parser() {
let acf = r#"
HAG(quarantine) { gone.invalid, alive.invalid }
ASG(DEFAULT) {
RULE(1, WRITE) { HAG(quarantine) }
}
"#;
let config = parse_acf(acf).expect("parser must not abort on bad DNS");
let entries = &config.hag["quarantine"];
assert_eq!(
entries.len(),
2,
"literal entries preserved verbatim; no resolved IPs appended"
);
assert_eq!(entries[0], "gone.invalid");
assert_eq!(entries[1], "alive.invalid");
}
#[test]
fn test_check_access_default_rw() {
let acf = "ASG(DEFAULT) { RULE(1, WRITE) RULE(1, READ) }";
let config = parse_acf(acf).unwrap();
assert_eq!(
config.check_access("DEFAULT", "host1", "user1"),
AccessLevel::ReadWrite
);
}
#[test]
fn test_check_access_read_only() {
let acf = r#"
UAG(admins) { admin1 }
ASG(READONLY) {
RULE(1, READ)
RULE(1, WRITE) { UAG(admins) }
}
"#;
let config = parse_acf(acf).unwrap();
assert_eq!(
config.check_access("READONLY", "host1", "admin1"),
AccessLevel::ReadWrite
);
assert_eq!(
config.check_access("READONLY", "host1", "regular"),
AccessLevel::Read
);
}
#[test]
fn test_check_access_hag_uag_match() {
let acf = r#"
UAG(ops) { alice }
HAG(lab) { lab-pc1 }
ASG(CONTROLLED) {
RULE(1, WRITE) { UAG(ops) HAG(lab) }
RULE(1, READ)
}
"#;
let config = parse_acf(acf).unwrap();
assert_eq!(
config.check_access("CONTROLLED", "lab-pc1", "alice"),
AccessLevel::ReadWrite
);
assert_eq!(
config.check_access("CONTROLLED", "other-host", "alice"),
AccessLevel::Read
);
assert_eq!(
config.check_access("CONTROLLED", "lab-pc1", "bob"),
AccessLevel::Read
);
}
#[test]
fn test_check_access_unknown_user() {
let acf = r#"
ASG(DEFAULT) {
RULE(1, WRITE)
RULE(1, READ)
}
"#;
let config = parse_acf(acf).unwrap();
assert_eq!(
config.check_access("DEFAULT", "", ""),
AccessLevel::ReadWrite
);
}
#[test]
fn parse_acf_captures_method_and_authority() {
let acf = r#"
ASG(SECURE) {
RULE(1, WRITE) {
METHOD("ca", "x509")
AUTHORITY("ANL CA")
}
RULE(1, READ)
}
"#;
let config = parse_acf(acf).unwrap();
let asg = &config.asg["SECURE"];
assert_eq!(asg.rules.len(), 2);
assert_eq!(asg.rules[0].method, vec!["ca", "x509"]);
assert_eq!(asg.rules[0].authority, vec!["ANL CA"]);
assert!(
asg.rules[1].method.is_empty(),
"READ rule must not inherit METHOD list",
);
assert!(asg.rules[1].authority.is_empty());
}
#[test]
fn tls_x509_acf_rule_grants_write_on_issuer_match() {
let cfg = parse_acf(
r#"
ASG(TLS_ONLY) {
RULE(1, WRITE) { METHOD("x509") AUTHORITY("CN=ops-ca, O=Lab") }
RULE(1, READ)
}
"#,
)
.unwrap();
assert_eq!(
cfg.check_access_method("TLS_ONLY", "h", "u", 0, "", ""),
AccessLevel::Read
);
assert_eq!(
cfg.check_access_method("TLS_ONLY", "h", "u", 0, "x509", "CN=other-ca"),
AccessLevel::Read
);
assert_eq!(
cfg.check_access_method("TLS_ONLY", "h", "u", 0, "x509", "CN=ops-ca, O=Lab"),
AccessLevel::ReadWrite
);
}
#[test]
fn check_access_method_gates_on_method() {
let acf = r#"
ASG(METHOD_GATED) {
RULE(1, WRITE) {
METHOD("x509")
}
RULE(1, READ)
}
"#;
let config = parse_acf(acf).unwrap();
assert_eq!(
config.check_access_method("METHOD_GATED", "h", "u", 0, "x509", ""),
AccessLevel::ReadWrite
);
assert_eq!(
config.check_access_method("METHOD_GATED", "h", "u", 0, "ca", ""),
AccessLevel::Read
);
}
#[test]
fn check_access_method_gates_on_authority() {
let acf = r#"
ASG(AUTH_GATED) {
RULE(1, WRITE) {
AUTHORITY("Trusted Root")
}
RULE(1, READ)
}
"#;
let config = parse_acf(acf).unwrap();
assert_eq!(
config.check_access_method("AUTH_GATED", "h", "u", 0, "x509", "Trusted Root"),
AccessLevel::ReadWrite
);
assert_eq!(
config.check_access_method("AUTH_GATED", "h", "u", 0, "x509", "Other CA"),
AccessLevel::Read
);
}
#[test]
fn check_access_asl_legacy_path_matches_when_method_empty() {
let acf = r#"
ASG(LEGACY) {
RULE(1, WRITE)
RULE(1, READ)
}
"#;
let config = parse_acf(acf).unwrap();
assert_eq!(
config.check_access_asl("LEGACY", "h", "u", 0),
AccessLevel::ReadWrite
);
}
#[test]
fn check_access_method_match_is_case_insensitive() {
let acf = r#"
ASG(MIXED_CASE) {
RULE(1, WRITE) {
METHOD("X509")
}
}
"#;
let config = parse_acf(acf).unwrap();
assert_eq!(
config.check_access_method("MIXED_CASE", "h", "u", 0, "x509", ""),
AccessLevel::ReadWrite
);
}
#[test]
fn empty_rule_asg_denies_access() {
let config = parse_acf("ASG(LOCKED) { }").unwrap();
assert_eq!(
config.check_access("LOCKED", "host", "user"),
AccessLevel::NoAccess,
"ASG with no RULE must deny — C asComputePvt fails closed"
);
}
#[test]
fn unknown_asg_falls_back_to_empty_default_and_denies() {
let config = parse_acf("UAG(ops) { alice }").unwrap();
assert!(config.asg.contains_key("DEFAULT"));
assert_eq!(
config.check_access("TYPO", "host", "alice"),
AccessLevel::NoAccess,
"unknown ASG must resolve to empty DEFAULT ⇒ NoAccess"
);
}
#[test]
fn default_asg_without_rules_denies() {
let config = parse_acf("UAG(ops) { alice }").unwrap();
assert_eq!(
config.check_access("DEFAULT", "host", "alice"),
AccessLevel::NoAccess
);
}
#[test]
fn empty_acf_denies_all_access() {
for acf in ["", "# just a comment\n", "UAG(ops){alice}\nHAG(h){pc1}\n"] {
let config = parse_acf(acf).unwrap();
assert_eq!(
config.check_access("DEFAULT", "host", "alice"),
AccessLevel::NoAccess,
"empty/rule-less ACF must deny (input was {acf:?})"
);
assert_eq!(
config.check_access("ANY_GROUP", "host", "alice"),
AccessLevel::NoAccess,
"unknown ASG against empty ACF must deny (input was {acf:?})"
);
}
}
#[test]
fn handbuilt_config_missing_default_denies() {
let config = AccessSecurityConfig {
uag: HashMap::new(),
hag: HashMap::new(),
asg: HashMap::new(),
unknown_access: AccessLevel::Read,
};
assert_eq!(
config.check_access("WHATEVER", "host", "user"),
AccessLevel::NoAccess
);
}
#[test]
fn rule_none_grants_no_access() {
let config = parse_acf("ASG(N) { RULE(0, NONE) }").unwrap();
assert_eq!(
config.check_access("N", "host", "user"),
AccessLevel::NoAccess
);
}
#[test]
fn rule_unsupported_access_keyword_is_inert() {
let config = parse_acf("ASG(B) { RULE(0, WRIET) }").unwrap();
assert_eq!(config.asg["B"].rules.len(), 1);
assert!(config.asg["B"].rules[0].ignore, "bad keyword ⇒ inert rule");
assert_eq!(
config.check_access("B", "host", "user"),
AccessLevel::NoAccess
);
}
#[test]
fn rule_negative_level_is_rejected() {
let err = parse_acf("ASG(X) { RULE(-1, READ) }");
assert!(err.is_err(), "negative RULE level must fail the parse");
}
#[test]
fn rule_non_numeric_level_is_rejected() {
let err = parse_acf("ASG(X) { RULE(abc, READ) }");
assert!(err.is_err(), "non-numeric RULE level must fail the parse");
}
#[test]
fn unknown_top_level_block_is_skipped_not_fatal() {
let acf = r#"
VENDOR(extension) { whatever }
ASG(DEFAULT) { RULE(1, READ) }
"#;
let config = parse_acf(acf).expect("unknown top-level block must not abort the parse");
assert_eq!(
config.check_access("DEFAULT", "host", "user"),
AccessLevel::Read,
"the ASG after the unknown block must still parse"
);
}
#[test]
fn unknown_well_formed_block_parses_ok_with_warning() {
let acf = r#"
VENDOR(x) { FOO(1) }
ASG(DEFAULT) { RULE(1, READ) }
"#;
let config = parse_acf(acf)
.expect("a well-formed unknown top-level block must warn-and-continue, not fail");
assert_eq!(
config.check_access("DEFAULT", "host", "user"),
AccessLevel::Read
);
}
/// An unrecognised block head with no brace body (`VENDOR(x)`) directly
/// followed by an ASG must parse, and the ASG must be registered.
#[test]
fn unknown_block_bare_head_parses_ok() {
    let source = "VENDOR(x) ASG(DEFAULT) { RULE(1, READ) }";
    let parsed = parse_acf(source).expect("bare unknown-block head must warn-and-continue");
    let has_default = parsed.asg.contains_key("DEFAULT");
    assert!(has_default);
}
/// Free-form text that is not ACF at all must produce a parse error.
#[test]
fn genuine_garbage_acf_is_rejected() {
    let outcome = parse_acf("this is not valid ACF (((");
    assert!(
        outcome.is_err(),
        "unparseable ACF must fail, not silently skip to EOF"
    );
}
/// Input made up solely of stray punctuation (`(((` or a lone `}`) must be a
/// parse error.
#[test]
fn stray_top_level_punctuation_is_rejected() {
    let open_parens_only = parse_acf("(((");
    assert!(
        open_parens_only.is_err(),
        "a file of only '(((' must fail, not silently skip to EOF"
    );
    // `}}` in the message below is the format-string escape for a single `}`.
    let close_brace_only = parse_acf("}");
    assert!(
        close_brace_only.is_err(),
        "a file of only '}}' must fail, not silently skip to EOF"
    );
}
/// Empty, whitespace-only and comment-only inputs must all parse
/// successfully.
#[test]
fn empty_and_comment_only_acf_still_parses_ok() {
    let empty = parse_acf("");
    assert!(empty.is_ok(), "empty file must parse Ok");

    let blank = parse_acf(" \n\t \n");
    assert!(blank.is_ok(), "whitespace-only file must parse Ok");

    let comments = parse_acf("# just a comment\n# another\n");
    assert!(comments.is_ok(), "comment-only file must parse Ok");
}
/// An unrecognised keyword that is followed by a bare word rather than a
/// parenthesised head must be a parse error.
#[test]
fn unknown_keyword_without_paren_head_is_rejected() {
    // Failure message added for parity with the sibling rejection tests.
    assert!(
        parse_acf("VENDOR something").is_err(),
        "an unknown keyword not followed by a '(' head must fail the parse"
    );
}
/// An unrecognised keyword that is the entire input must be a parse error.
#[test]
fn unknown_keyword_at_eof_is_rejected() {
    // Failure message added for parity with the sibling rejection tests.
    assert!(
        parse_acf("VENDOR").is_err(),
        "a bare unknown keyword at end of input must fail the parse"
    );
}
/// An unrecognised block whose parenthesised head never closes must be a
/// parse error.
#[test]
fn unknown_block_unbalanced_paren_is_rejected() {
    // Failure message added for parity with the sibling rejection tests.
    assert!(
        parse_acf("VENDOR(((").is_err(),
        "an unknown block with unbalanced parentheses must fail the parse"
    );
}
/// An unrecognised block whose brace body is never terminated must be a
/// parse error.
#[test]
fn unknown_block_unbalanced_brace_is_rejected() {
    // Failure message added for parity with the sibling rejection tests.
    assert!(
        parse_acf("VENDOR(x) { unterminated").is_err(),
        "an unknown block whose brace body never closes must fail the parse"
    );
}
/// A HAG entry written in mixed case (`LabPC1.invalid`) must match client
/// hosts regardless of letter case, while an unrelated host only gets the
/// fallback `READ` rule.
#[test]
fn hag_host_match_is_case_insensitive() {
    // Raw-string content is kept flush-left so the parsed bytes are unchanged.
    let source = r#"
HAG(lab) { LabPC1.invalid }
ASG(C) {
RULE(1, WRITE) { HAG(lab) }
RULE(1, READ)
}
"#;
    let parsed = parse_acf(source).unwrap();

    let lower_case_host = parsed.check_access("C", "labpc1.invalid", "user");
    assert_eq!(
        lower_case_host,
        AccessLevel::ReadWrite,
        "lowercased HAG entry must match a mixed-case client host"
    );

    let upper_case_host = parsed.check_access("C", "LABPC1.INVALID", "user");
    assert_eq!(upper_case_host, AccessLevel::ReadWrite);

    let unrelated_host = parsed.check_access("C", "other.invalid", "user");
    assert_eq!(unrelated_host, AccessLevel::Read);
}
/// The optional third RULE argument is parsed: `TRAPWRITE` sets the rule's
/// `trap` flag and `NOTRAPWRITE` leaves it cleared, with the access keyword
/// recorded as usual.
#[test]
fn rule_trapwrite_log_option_parses() {
    let parsed =
        parse_acf("ASG(T) { RULE(1, WRITE, TRAPWRITE) RULE(1, READ, NOTRAPWRITE) }").unwrap();
    let rules = &parsed.asg["T"].rules;
    assert_eq!(rules.len(), 2);

    let trapping = &rules[0];
    assert_eq!(trapping.access, RuleAccess::Write);
    assert!(trapping.trap, "TRAPWRITE must set the trap mask");

    let non_trapping = &rules[1];
    assert_eq!(non_trapping.access, RuleAccess::Read);
    assert!(!non_trapping.trap, "NOTRAPWRITE must clear the trap mask");
}
/// A RULE whose third argument is not a recognised log option (`BOGUS`) must
/// be a parse error.
#[test]
fn rule_bad_log_option_is_rejected() {
    // Failure message added for parity with the sibling rejection tests.
    assert!(
        parse_acf("ASG(T) { RULE(1, WRITE, BOGUS) }").is_err(),
        "an unrecognised RULE log option must fail the parse"
    );
}
/// `check_access_method_trap` must report the trap flag of the rule that
/// granted access: true for `TRAPWRITE`, false for `NOTRAPWRITE`, false when
/// no option is given, and false when the group grants nothing.
#[test]
fn mr_r20_trap_mask_reflects_matched_rule() {
    // Raw-string content is kept flush-left so the parsed bytes are unchanged.
    let parsed = parse_acf(
        r#"
ASG(TRAPPED) { RULE(0, WRITE, TRAPWRITE) }
ASG(UNTRAPPED) { RULE(0, WRITE, NOTRAPWRITE) }
ASG(PLAIN) { RULE(0, WRITE) }
ASG(LOCKED) { }
"#,
    )
    .unwrap();

    let (trapped_level, trapped_flag) =
        parsed.check_access_method_trap("TRAPPED", "h", "u", 0, "", "");
    assert_eq!(trapped_level, AccessLevel::ReadWrite);
    assert!(
        trapped_flag,
        "a TRAPWRITE rule must resolve rule_was_trap = true"
    );

    let (untrapped_level, untrapped_flag) =
        parsed.check_access_method_trap("UNTRAPPED", "h", "u", 0, "", "");
    assert_eq!(untrapped_level, AccessLevel::ReadWrite);
    assert!(
        !untrapped_flag,
        "a NOTRAPWRITE rule must resolve rule_was_trap = false"
    );

    let (plain_level, plain_flag) = parsed.check_access_method_trap("PLAIN", "h", "u", 0, "", "");
    assert_eq!(plain_level, AccessLevel::ReadWrite);
    assert!(
        !plain_flag,
        "a rule with no trap option must resolve rule_was_trap = false"
    );

    // An ASG with an empty body grants nothing, so there is no rule to trap.
    let (locked_level, locked_flag) =
        parsed.check_access_method_trap("LOCKED", "h", "u", 0, "", "");
    assert_eq!(locked_level, AccessLevel::NoAccess);
    assert!(
        !locked_flag,
        "a denied resolution must carry rule_was_trap = false"
    );
}
/// When a READ rule is followed by a WRITE rule, the reported trap flag must
/// come from the WRITE rule that raised the level to `ReadWrite`.
#[test]
fn mr_r20_trap_mask_follows_last_access_raising_rule() {
    let with_trap = parse_acf("ASG(M) { RULE(0, READ) RULE(0, WRITE, TRAPWRITE) }").unwrap();
    let (level_m, trap_m) = with_trap.check_access_method_trap("M", "h", "u", 0, "", "");
    assert_eq!(level_m, AccessLevel::ReadWrite);
    assert!(
        trap_m,
        "trap mask must follow the WRITE rule that raised access"
    );

    let without_trap = parse_acf("ASG(N) { RULE(0, READ) RULE(0, WRITE, NOTRAPWRITE) }").unwrap();
    let (level_n, trap_n) = without_trap.check_access_method_trap("N", "h", "u", 0, "", "");
    assert_eq!(level_n, AccessLevel::ReadWrite);
    assert!(!trap_n, "NOTRAPWRITE on the access-raising rule must win");
}
/// A rule carrying a `CALC` clause is parsed with the clause stored, is
/// flagged `ignore`, and therefore does not grant its WRITE access.
#[test]
fn calc_rule_is_disabled_when_unevaluable() {
    let parsed = parse_acf(r#"ASG(G) { INPA("ref") RULE(1, WRITE) { CALC("A=1") } }"#).unwrap();
    let calc_rule = &parsed.asg["G"].rules[0];
    assert!(
        calc_rule.calc.is_some(),
        "CALC clause must be parsed and stored"
    );
    assert!(
        calc_rule.ignore,
        "an unevaluable CALC rule must be disabled, not unconditional"
    );
    let granted = parsed.check_access("G", "host", "user");
    assert_eq!(
        granted,
        AccessLevel::NoAccess,
        "CALC rule must not silently grant WRITE"
    );
}
/// A `CALC` clause with a syntactically incomplete expression (`A=`) must
/// make the parse fail.
#[test]
fn calc_rule_with_bad_expression_is_rejected() {
    let outcome = parse_acf(r#"ASG(G) { RULE(1, WRITE) { CALC("A=") } }"#);
    assert!(
        outcome.is_err(),
        "syntactically broken CALC must fail the parse"
    );
}
/// `INPA` and `INPC` entries inside an ASG are collected in order, with index
/// 0 and 2 respectively and their link strings preserved.
#[test]
fn asg_inp_links_are_parsed() {
    // Raw-string content is kept flush-left so the parsed bytes are unchanged.
    let source = r#"
ASG(G) {
INPA("rec1.VAL")
INPC("rec3.VAL")
RULE(1, READ)
}
"#;
    let parsed = parse_acf(source).unwrap();
    let links = &parsed.asg["G"].inp;
    assert_eq!(links.len(), 2);

    let first = &links[0];
    assert_eq!(first.index, 0);
    assert_eq!(first.link, "rec1.VAL");

    let second = &links[1];
    assert_eq!(second.index, 2);
    assert_eq!(second.link, "rec3.VAL");
}
/// An `INP` entry using the selector letter `Z` must be a parse error.
#[test]
fn asg_inp_bad_selector_is_rejected() {
    // Failure message added for parity with the sibling rejection tests.
    assert!(
        parse_acf(r#"ASG(G) { INPZ("x") }"#).is_err(),
        "an INP link with selector Z must fail the parse"
    );
}
/// An unquoted parenthesised name containing a space (`my group`) must be a
/// parse error.
#[test]
fn paren_name_rejects_embedded_whitespace() {
    // Failure message added for parity with the sibling rejection tests.
    assert!(
        parse_acf("UAG(my group) { x }").is_err(),
        "an unquoted name containing whitespace must fail the parse"
    );
}
/// A parenthesised name whose closing `)` is missing must be a parse error.
#[test]
fn paren_name_rejects_unterminated() {
    // Failure message added for parity with the sibling rejection tests.
    assert!(
        parse_acf("UAG(unterminated").is_err(),
        "a name whose '(' is never closed must fail the parse"
    );
}
/// A double-quoted parenthesised name may contain a space and is stored
/// without the quotes.
#[test]
fn paren_name_accepts_quoted_form() {
    let parsed = parse_acf(r#"UAG("my group") { x }"#).unwrap();
    let has_group = parsed.uag.contains_key("my group");
    assert!(has_group);
}
/// With rules at levels 0 and 1, a query at level 0 resolves to `ReadWrite`
/// while a query at level 2 resolves to `NoAccess`.
#[test]
fn asl_gate_still_honoured_after_fail_closed_rewrite() {
    let parsed = parse_acf("ASG(A) { RULE(0, READ) RULE(1, WRITE) }").unwrap();

    let at_level_zero = parsed.check_access_method("A", "h", "u", 0, "", "");
    assert_eq!(at_level_zero, AccessLevel::ReadWrite);

    let at_level_two = parsed.check_access_method("A", "h", "u", 2, "", "");
    assert_eq!(at_level_two, AccessLevel::NoAccess);
}
}