use std::collections::HashMap;
use crate::error::{CaError, CaResult};
/// Effective permission a peer holds on a PV.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AccessLevel {
    /// Neither reads nor writes are permitted.
    NoAccess,
    /// Reads are permitted, writes are not.
    Read,
    /// Both reads and writes are permitted.
    ReadWrite,
}

/// Outcome of an access check for a single PV.
///
/// All fields are private and `_seal` has a module-private type, so values of
/// this type can only be built inside this module (by `AccessGate::check`).
#[derive(Debug, Clone)]
pub struct AccessChecked {
    pv_name: String,
    level: AccessLevel,
    _seal: AccessSeal,
}

/// Zero-sized marker that keeps `AccessChecked` unconstructible elsewhere.
#[derive(Debug, Clone)]
struct AccessSeal;

impl AccessChecked {
    /// Name of the PV this result was computed for.
    pub fn pv_name(&self) -> &str {
        self.pv_name.as_str()
    }

    /// The granted access level.
    pub fn level(&self) -> AccessLevel {
        self.level
    }

    /// `true` for `Read` and `ReadWrite`.
    pub fn allows_read(&self) -> bool {
        self.level != AccessLevel::NoAccess
    }

    /// `true` only for `ReadWrite`.
    pub fn allows_write(&self) -> bool {
        self.level == AccessLevel::ReadWrite
    }
}
/// Asynchronous callback mapping a PV name to its `(ASG name, ASL)` pair.
pub type AsgAslResolver = std::sync::Arc<
    dyn Fn(String) -> std::pin::Pin<Box<dyn std::future::Future<Output = (String, u8)> + Send>>
        + Send
        + Sync,
>;

/// Where `AccessGate::acl_version` reads its value from.
#[derive(Clone)]
enum AclVersionSource {
    /// Shared counter; `AccessGate::bump_acl_version` increments it.
    Atomic(std::sync::Arc<std::sync::atomic::AtomicU64>),
    /// Caller-supplied function queried on every read; bumping is a no-op.
    Aggregator(std::sync::Arc<dyn Fn() -> u64 + Send + Sync>),
}

/// Policy the gate applies in `AccessGate::check`.
enum AccessGateInner {
    /// Consult the shared (possibly not yet loaded) ACF through `resolver`.
    Required {
        acf: std::sync::Arc<tokio::sync::RwLock<Option<AccessSecurityConfig>>>,
        resolver: AsgAslResolver,
    },
    /// Grant `ReadWrite` to everyone.
    Open,
}

/// Entry point for PV access checks, plus an ACL version number.
pub struct AccessGate {
    inner: AccessGateInner,
    acl_version: AclVersionSource,
}
impl AccessGate {
pub fn required(
acf: std::sync::Arc<tokio::sync::RwLock<Option<AccessSecurityConfig>>>,
resolver: AsgAslResolver,
) -> Self {
Self::required_with_version(
acf,
resolver,
std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0)),
)
}
pub fn required_with_version(
acf: std::sync::Arc<tokio::sync::RwLock<Option<AccessSecurityConfig>>>,
resolver: AsgAslResolver,
acl_version: std::sync::Arc<std::sync::atomic::AtomicU64>,
) -> Self {
Self {
inner: AccessGateInner::Required { acf, resolver },
acl_version: AclVersionSource::Atomic(acl_version),
}
}
pub fn open() -> Self {
Self {
inner: AccessGateInner::Open,
acl_version: AclVersionSource::Atomic(std::sync::Arc::new(
std::sync::atomic::AtomicU64::new(0),
)),
}
}
pub fn open_with_aggregator(f: std::sync::Arc<dyn Fn() -> u64 + Send + Sync>) -> Self {
Self {
inner: AccessGateInner::Open,
acl_version: AclVersionSource::Aggregator(f),
}
}
pub fn acl_version(&self) -> u64 {
match &self.acl_version {
AclVersionSource::Atomic(a) => a.load(std::sync::atomic::Ordering::Acquire),
AclVersionSource::Aggregator(f) => f(),
}
}
pub fn bump_acl_version(&self) {
if let AclVersionSource::Atomic(a) = &self.acl_version {
a.fetch_add(1, std::sync::atomic::Ordering::Release);
}
}
pub async fn check(
&self,
pv_name: impl Into<String>,
host: &str,
user: &str,
method: &str,
authority: &str,
) -> AccessChecked {
let pv_name = pv_name.into();
let level = match &self.inner {
AccessGateInner::Open => AccessLevel::ReadWrite,
AccessGateInner::Required { acf, resolver } => {
let guard = acf.read().await;
match *guard {
None => AccessLevel::ReadWrite,
Some(ref cfg) => {
let (asg, asl) = resolver(pv_name.clone()).await;
cfg.check_access_method(&asg, host, user, asl, method, authority)
}
}
}
};
AccessChecked {
pv_name,
level,
_seal: AccessSeal,
}
}
}
#[cfg(test)]
mod access_checked_tests {
    use super::*;
    use std::sync::Arc;

    /// Resolver that places every PV in ASG `DEFAULT` at ASL 0.
    fn default_resolver() -> AsgAslResolver {
        Arc::new(|_pv| Box::pin(async { ("DEFAULT".to_string(), 0u8) }))
    }

    #[tokio::test]
    async fn open_gate_grants_read_write() {
        let result = AccessGate::open()
            .check("any:pv", "h", "u", "anonymous", "")
            .await;
        assert_eq!(result.level(), AccessLevel::ReadWrite);
        assert!(result.allows_read());
        assert!(result.allows_write());
        assert_eq!(result.pv_name(), "any:pv");
    }

    #[tokio::test]
    async fn required_gate_with_no_acf_attached_is_permissive() {
        let empty_cell = Arc::new(tokio::sync::RwLock::new(None));
        let gate = AccessGate::required(empty_cell, default_resolver());
        let result = gate.check("any:pv", "h", "u", "anonymous", "").await;
        assert_eq!(result.level(), AccessLevel::ReadWrite);
    }

    #[tokio::test]
    async fn required_gate_with_acf_denies_unprivileged_peer() {
        let acf = r#"
UAG(ops) { alice }
ASG(DEFAULT) {
RULE(0, READ) { UAG(ops) }
}
"#;
        let cell = Arc::new(tokio::sync::RwLock::new(Some(parse_acf(acf).unwrap())));
        let gate = AccessGate::required(cell, default_resolver());

        let member = gate.check("x", "h", "alice", "anonymous", "").await;
        assert!(member.allows_read());
        assert!(!member.allows_write());

        let outsider = gate.check("x", "h", "intruder", "anonymous", "").await;
        assert_eq!(outsider.level(), AccessLevel::NoAccess);
        assert!(!outsider.allows_read());
    }
}
/// One `RULE(level, access) { ... }` entry of an access security group.
///
/// Each non-empty list restricts the rule: the peer must match at least one
/// entry of every non-empty list. An empty list imposes no restriction.
#[derive(Debug, Clone, Default)]
pub struct AccessRule {
    /// Highest record ASL the rule applies to (records with a larger ASL skip it).
    pub level: u8,
    /// `true` grants read and write; `false` grants read only.
    pub write: bool,
    /// User access group names; membership in any one of them suffices.
    pub uag: Vec<String>,
    /// Host access group names; membership in any one of them suffices.
    pub hag: Vec<String>,
    /// Accepted authentication method names (compared case-insensitively).
    pub method: Vec<String>,
    /// Accepted authority strings (compared case-insensitively).
    pub authority: Vec<String>,
}

/// An `ASG(name) { ... }` block: its rules in file order.
#[derive(Debug, Clone, Default)]
pub struct AccessSecurityGroup {
    pub rules: Vec<AccessRule>,
}

/// A parsed access security configuration file.
#[derive(Debug, Clone)]
pub struct AccessSecurityConfig {
    /// UAG name -> member user names.
    pub uag: HashMap<String, Vec<String>>,
    /// HAG name -> member host names (and, after parsing, their resolved IPs).
    pub hag: HashMap<String, Vec<String>>,
    /// ASG name -> group definition.
    pub asg: HashMap<String, AccessSecurityGroup>,
    /// Level returned for peers whose user or host name is empty.
    pub unknown_access: AccessLevel,
}
impl AccessSecurityConfig {
pub fn check_access(&self, asg_name: &str, host: &str, user: &str) -> AccessLevel {
self.check_access_asl(asg_name, host, user, 0)
}
pub fn check_access_method(
&self,
asg_name: &str,
host: &str,
user: &str,
record_asl: u8,
method: &str,
authority: &str,
) -> AccessLevel {
let asg = match self.asg.get(asg_name) {
Some(a) => a,
None => match self.asg.get("DEFAULT") {
Some(a) => a,
None => return AccessLevel::ReadWrite,
},
};
if asg.rules.is_empty() {
return AccessLevel::ReadWrite;
}
if user.is_empty() || host.is_empty() {
return self.unknown_access;
}
let mut can_read = false;
let mut can_write = false;
for rule in &asg.rules {
if record_asl > rule.level {
continue;
}
let user_match = rule.uag.is_empty()
|| rule.uag.iter().any(|g| {
self.uag
.get(g)
.map(|members| members.iter().any(|m| m == user))
.unwrap_or(false)
});
let host_match = rule.hag.is_empty()
|| rule.hag.iter().any(|g| {
self.hag
.get(g)
.map(|members| members.iter().any(|m| m == host))
.unwrap_or(false)
});
let method_match = rule.method.is_empty()
|| rule.method.iter().any(|m| m.eq_ignore_ascii_case(method));
let authority_match = rule.authority.is_empty()
|| rule
.authority
.iter()
.any(|a| a.eq_ignore_ascii_case(authority));
if user_match && host_match && method_match && authority_match {
if rule.write {
can_write = true;
can_read = true;
} else {
can_read = true;
}
}
}
match (can_read, can_write) {
(_, true) => AccessLevel::ReadWrite,
(true, false) => AccessLevel::Read,
_ => AccessLevel::NoAccess,
}
}
pub fn check_access_asl(
&self,
asg_name: &str,
host: &str,
user: &str,
record_asl: u8,
) -> AccessLevel {
self.check_access_method(asg_name, host, user, record_asl, "", "")
}
}
/// Parses the text of an access security configuration file (ACF).
///
/// Top-level definitions are `UAG(name) { ... }`, `HAG(name) { ... }` and
/// `ASG(name) { ... }`. `#` starts a comment that runs to the end of the line.
/// HAG members are expanded with their resolved IP addresses. A later definition
/// with the same name replaces an earlier one. `unknown_access` is initialised
/// to `AccessLevel::Read`.
///
/// # Errors
///
/// Returns `CaError::Protocol` in three cases:
/// - an unknown top-level keyword;
/// - a character that cannot start a keyword;
/// - a malformed definition.
pub fn parse_acf(content: &str) -> CaResult<AccessSecurityConfig> {
    let mut config = AccessSecurityConfig {
        uag: HashMap::new(),
        hag: HashMap::new(),
        asg: HashMap::new(),
        unknown_access: AccessLevel::Read,
    };
    let mut chars = content.chars().peekable();
    let mut keyword = String::new();
    loop {
        skip_ws_comments(&mut chars);
        // Only genuine end of input ends the parse.
        let next = match chars.peek() {
            Some(&c) => c,
            None => break,
        };
        keyword.clear();
        read_word(&mut chars, &mut keyword);
        match keyword.as_str() {
            "UAG" => {
                let name = read_paren_name(&mut chars)?;
                let members = read_brace_list(&mut chars)?;
                config.uag.insert(name, members);
            }
            "HAG" => {
                let name = read_paren_name(&mut chars)?;
                let members = read_brace_list(&mut chars)?;
                config.hag.insert(name, expand_hag_members(&members));
            }
            "ASG" => {
                let name = read_paren_name(&mut chars)?;
                let asg = parse_asg_body(&mut chars)?;
                config.asg.insert(name, asg);
            }
            // A stray character (e.g. an extra '}'). Stopping here would silently
            // discard every definition after it, so reject the file instead.
            "" => {
                return Err(CaError::Protocol(format!(
                    "ACF: unexpected character '{next}'"
                )));
            }
            other => {
                return Err(CaError::Protocol(format!(
                    "ACF: unexpected keyword '{other}'"
                )));
            }
        }
    }
    Ok(config)
}
/// Returns the HAG members followed by the IP addresses their names resolve to.
///
/// Every literal member is kept, in order. Members that already parse as IP
/// addresses are not looked up. Resolved addresses are appended after their
/// member only if not already present. Lookup failures are logged at debug
/// level and otherwise ignored.
///
/// NOTE(review): `ToSocketAddrs` resolves synchronously (it may block on DNS);
/// calling this from an async task blocks that executor thread.
fn expand_hag_members(members: &[String]) -> Vec<String> {
    use std::net::ToSocketAddrs;
    let mut expanded: Vec<String> = Vec::with_capacity(members.len());
    for member in members {
        expanded.push(member.clone());
        let is_ip_literal = member.parse::<std::net::IpAddr>().is_ok();
        if is_ip_literal {
            continue;
        }
        match format!("{member}:0").to_socket_addrs() {
            Err(e) => {
                tracing::debug!(
                    target: "epics_base_rs::access_security",
                    host = %member,
                    error = %e,
                    "ACF HAG: DNS lookup failed; keeping literal entry (libcom 932e9f3 soft fallback)"
                );
            }
            Ok(addrs) => {
                for addr in addrs {
                    let ip = addr.ip().to_string();
                    if !expanded.contains(&ip) {
                        expanded.push(ip);
                    }
                }
            }
        }
    }
    expanded
}
/// Advances past whitespace and `#` comments (each running through its newline).
///
/// Stops at the first significant character, leaving it unconsumed.
fn skip_ws_comments(chars: &mut std::iter::Peekable<std::str::Chars>) {
    loop {
        match chars.peek().copied() {
            Some(c) if c.is_whitespace() => {
                chars.next();
            }
            Some('#') => {
                // Consume the '#' and everything up to and including '\n'.
                for c in chars.by_ref() {
                    if c == '\n' {
                        break;
                    }
                }
            }
            _ => return,
        }
    }
}
/// Appends the longest run of alphanumeric/`_` characters to `buf`.
///
/// `buf` is not cleared first. The first non-word character is left unconsumed.
fn read_word(chars: &mut std::iter::Peekable<std::str::Chars>, buf: &mut String) {
    while let Some(c) = chars.next_if(|c| c.is_alphanumeric() || *c == '_') {
        buf.push(c);
    }
}
/// Reads a parenthesised name such as `(ops)` and returns its contents.
///
/// Whitespace/comments before the `(` and directly after it are skipped.
/// Whitespace inside the parentheses is dropped, so `( a b )` yields `"ab"`.
///
/// # Errors
///
/// `CaError::Protocol` if the next significant character is not `(`, or if the
/// input ends before the closing `)`.
fn read_paren_name(chars: &mut std::iter::Peekable<std::str::Chars>) -> CaResult<String> {
    skip_ws_comments(chars);
    if chars.next() != Some('(') {
        return Err(CaError::Protocol("ACF: expected '('".into()));
    }
    skip_ws_comments(chars);
    let mut name = String::new();
    loop {
        match chars.next() {
            Some(')') => return Ok(name),
            Some(c) if c.is_whitespace() => {}
            Some(c) => name.push(c),
            // A truncated file must not silently define or reference a group
            // under a partial name.
            None => return Err(CaError::Protocol("ACF: unterminated '('".into())),
        }
    }
}
/// Reads a brace-delimited, comma-separated member list such as `{ alice, bob }`.
///
/// A member is either of these:
/// - a double-quoted string, taken verbatim (it may contain spaces or commas);
/// - a bare run of characters ending at whitespace, `,`, `}`, `#` or `"`.
///
/// Every character of a bare member is kept (e.g. `:` in an IPv6 literal).
/// Empty members, e.g. from `a,,b` or a trailing comma, are dropped.
///
/// # Errors
///
/// `CaError::Protocol` in any of these cases:
/// - the list does not start with `{`;
/// - the list is not closed by `}`;
/// - a quoted member is unterminated;
/// - two members are not separated by a comma.
fn read_brace_list(chars: &mut std::iter::Peekable<std::str::Chars>) -> CaResult<Vec<String>> {
    skip_ws_comments(chars);
    if chars.next() != Some('{') {
        return Err(CaError::Protocol("ACF: expected '{'".into()));
    }
    let mut items = Vec::new();
    let mut current = String::new();
    loop {
        skip_ws_comments(chars);
        match chars.peek().copied() {
            None => return Err(CaError::Protocol("ACF: unterminated '{'".into())),
            Some('}') => {
                chars.next();
                break;
            }
            Some(',') => {
                chars.next();
                if !current.is_empty() {
                    items.push(std::mem::take(&mut current));
                }
            }
            // A member is still pending. Accepting this would glue two names
            // together (`{ alice bob }` -> "alicebob").
            Some(_) if !current.is_empty() => {
                return Err(CaError::Protocol(
                    "ACF: expected ',' between list members".into(),
                ));
            }
            Some('"') => {
                chars.next();
                loop {
                    match chars.next() {
                        Some('"') => break,
                        Some(c) => current.push(c),
                        None => {
                            return Err(CaError::Protocol(
                                "ACF: unterminated '\"' in list".into(),
                            ));
                        }
                    }
                }
            }
            Some(_) => {
                // Dropping unexpected characters could turn one name into a
                // different one (e.g. "a@b" -> "ab"), so keep them all.
                while let Some(&c) = chars.peek() {
                    if c.is_whitespace() || matches!(c, ',' | '}' | '#' | '"') {
                        break;
                    }
                    current.push(c);
                    chars.next();
                }
            }
        }
    }
    if !current.is_empty() {
        items.push(current);
    }
    Ok(items)
}
/// Parses the `{ ... }` body of an `ASG(name)` definition.
///
/// Each `RULE` keyword is handed to `parse_rule`. Any other keyword (e.g. INPA)
/// is ignored, and stray non-word characters are skipped one at a time.
///
/// # Errors
///
/// `CaError::Protocol` if the body does not start with `{`, ends before its
/// closing `}`, or contains a rule `parse_rule` rejects.
fn parse_asg_body(
    chars: &mut std::iter::Peekable<std::str::Chars>,
) -> CaResult<AccessSecurityGroup> {
    skip_ws_comments(chars);
    match chars.next() {
        Some('{') => {}
        _ => return Err(CaError::Protocol("ACF: expected '{' after ASG name".into())),
    }
    let mut group = AccessSecurityGroup::default();
    let mut word = String::new();
    loop {
        skip_ws_comments(chars);
        match chars.peek().copied() {
            None => return Err(CaError::Protocol("ACF: unterminated ASG".into())),
            Some('}') => {
                chars.next();
                return Ok(group);
            }
            Some(_) => {
                word.clear();
                read_word(chars, &mut word);
                if word.is_empty() {
                    // Not a word character: step over it.
                    chars.next();
                } else if word == "RULE" {
                    group.rules.push(parse_rule(chars)?);
                }
            }
        }
    }
}
fn parse_rule(chars: &mut std::iter::Peekable<std::str::Chars>) -> CaResult<AccessRule> {
skip_ws_comments(chars);
if chars.next() != Some('(') {
return Err(CaError::Protocol("ACF: expected '(' after RULE".into()));
}
skip_ws_comments(chars);
let mut level_str = String::new();
while let Some(&c) = chars.peek() {
if c.is_ascii_digit() {
level_str.push(c);
chars.next();
} else {
break;
}
}
let level: u8 = level_str.parse().unwrap_or(1);
skip_ws_comments(chars);
if chars.peek() == Some(&',') {
chars.next();
}
skip_ws_comments(chars);
let mut access_str = String::new();
read_word(chars, &mut access_str);
let write = access_str.eq_ignore_ascii_case("WRITE");
skip_ws_comments(chars);
if chars.peek() == Some(&')') {
chars.next();
}
let mut uag = Vec::new();
let mut hag = Vec::new();
let mut method = Vec::new();
let mut authority = Vec::new();
skip_ws_comments(chars);
if chars.peek() == Some(&'{') {
chars.next();
loop {
skip_ws_comments(chars);
match chars.peek() {
Some(&'}') => {
chars.next();
break;
}
Some(_) => {
let mut kw = String::new();
read_word(chars, &mut kw);
if kw == "UAG" {
let name = read_paren_name(chars)?;
uag.push(name);
} else if kw == "HAG" {
let name = read_paren_name(chars)?;
hag.push(name);
} else if kw == "METHOD" {
method.extend(read_paren_string_list(chars)?);
} else if kw == "AUTHORITY" {
authority.extend(read_paren_string_list(chars)?);
} else if kw.is_empty() {
chars.next();
}
}
None => break,
}
}
}
Ok(AccessRule {
level,
write,
uag,
hag,
method,
authority,
})
}
/// Reads a parenthesised, comma-separated list such as `("ca", "x509")`.
///
/// Double quotes toggle a quoted section in which `,` and `)` are literal. The
/// quote characters themselves are discarded. Each entry is trimmed of
/// surrounding whitespace, and empty entries are dropped.
///
/// # Errors
///
/// `CaError::Protocol` if the list does not start with `(` or the input ends
/// before the closing `)`.
fn read_paren_string_list(
    chars: &mut std::iter::Peekable<std::str::Chars>,
) -> CaResult<Vec<String>> {
    /// Pushes `raw` trimmed, unless it is blank.
    fn push_trimmed(entries: &mut Vec<String>, raw: &str) {
        let entry = raw.trim();
        if !entry.is_empty() {
            entries.push(entry.to_string());
        }
    }

    skip_ws_comments(chars);
    if chars.next() != Some('(') {
        return Err(CaError::Protocol(
            "ACF: expected '(' after METHOD/AUTHORITY".into(),
        ));
    }
    let mut entries = Vec::new();
    let mut pending = String::new();
    let mut quoted = false;
    loop {
        match chars.next() {
            None => {
                return Err(CaError::Protocol(
                    "ACF: unterminated METHOD/AUTHORITY list".into(),
                ));
            }
            Some('"') => quoted = !quoted,
            Some(')') if !quoted => break,
            Some(',') if !quoted => {
                push_trimmed(&mut entries, &pending);
                pending.clear();
            }
            Some(c) => pending.push(c),
        }
    }
    push_trimmed(&mut entries, &pending);
    Ok(entries)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` as an ACF, failing the test on any parse error.
    fn parsed(src: &str) -> AccessSecurityConfig {
        parse_acf(src).unwrap()
    }

    #[test]
    fn test_parse_acf_basic() {
        let cfg = parsed(
            r#"
UAG(admins) { user1, user2 }
HAG(operators) { host1, host2 }
ASG(DEFAULT) {
RULE(1, WRITE) { UAG(admins) HAG(operators) }
RULE(1, READ)
}
"#,
        );
        assert_eq!(cfg.uag["admins"], ["user1", "user2"]);
        assert_eq!(cfg.hag["operators"], ["host1", "host2"]);
        assert!(cfg.asg.contains_key("DEFAULT"));
        assert_eq!(cfg.asg["DEFAULT"].rules.len(), 2);
    }

    #[test]
    fn test_parse_acf_hag_uag() {
        let cfg = parsed(
            r#"
UAG(ops) { alice, bob }
HAG(lab) { lab-pc1.invalid }
ASG(SECURE) {
RULE(1, WRITE) { UAG(ops) HAG(lab) }
RULE(1, READ)
}
"#,
        );
        assert_eq!(cfg.uag["ops"], ["alice", "bob"]);
        assert_eq!(cfg.hag["lab"], ["lab-pc1.invalid"]);
    }

    #[test]
    fn hag_dns_resolution_appends_ip_for_match() {
        let cfg = parsed(
            r#"
HAG(local) { localhost }
ASG(DEFAULT) {
RULE(1, WRITE) { HAG(local) }
}
"#,
        );
        let local = &cfg.hag["local"];
        assert!(
            local.iter().any(|s| s == "localhost"),
            "literal hostname always preserved"
        );
        assert!(
            local.iter().any(|s| s == "127.0.0.1"),
            "resolved IPv4 appended for IP-presenting peers"
        );
    }

    #[test]
    fn hag_unresolvable_name_does_not_abort_parser() {
        let acf = r#"
HAG(quarantine) { gone.invalid, alive.invalid }
ASG(DEFAULT) {
RULE(1, WRITE) { HAG(quarantine) }
}
"#;
        let cfg = parse_acf(acf).expect("parser must not abort on bad DNS");
        let quarantine = &cfg.hag["quarantine"];
        assert_eq!(
            quarantine.len(),
            2,
            "literal entries preserved verbatim; no resolved IPs appended"
        );
        assert_eq!(quarantine[0], "gone.invalid");
        assert_eq!(quarantine[1], "alive.invalid");
    }

    #[test]
    fn test_check_access_default_rw() {
        let cfg = parsed("ASG(DEFAULT) { RULE(1, WRITE) RULE(1, READ) }");
        let level = cfg.check_access("DEFAULT", "host1", "user1");
        assert_eq!(level, AccessLevel::ReadWrite);
    }

    #[test]
    fn test_check_access_read_only() {
        let cfg = parsed(
            r#"
UAG(admins) { admin1 }
ASG(READONLY) {
RULE(1, READ)
RULE(1, WRITE) { UAG(admins) }
}
"#,
        );
        let cases = [
            ("admin1", AccessLevel::ReadWrite),
            ("regular", AccessLevel::Read),
        ];
        for (user, expected) in cases {
            assert_eq!(cfg.check_access("READONLY", "host1", user), expected);
        }
    }

    #[test]
    fn test_check_access_hag_uag_match() {
        let cfg = parsed(
            r#"
UAG(ops) { alice }
HAG(lab) { lab-pc1 }
ASG(CONTROLLED) {
RULE(1, WRITE) { UAG(ops) HAG(lab) }
RULE(1, READ)
}
"#,
        );
        let cases = [
            ("lab-pc1", "alice", AccessLevel::ReadWrite),
            ("other-host", "alice", AccessLevel::Read),
            ("lab-pc1", "bob", AccessLevel::Read),
        ];
        for (host, user, expected) in cases {
            assert_eq!(cfg.check_access("CONTROLLED", host, user), expected);
        }
    }

    #[test]
    fn test_check_access_unknown_user() {
        let cfg = parsed(
            r#"
ASG(DEFAULT) {
RULE(1, WRITE)
RULE(1, READ)
}
"#,
        );
        assert_eq!(cfg.check_access("DEFAULT", "", ""), AccessLevel::Read);
    }

    #[test]
    fn parse_acf_captures_method_and_authority() {
        let cfg = parsed(
            r#"
ASG(SECURE) {
RULE(1, WRITE) {
METHOD("ca", "x509")
AUTHORITY("ANL CA")
}
RULE(1, READ)
}
"#,
        );
        let rules = &cfg.asg["SECURE"].rules;
        assert_eq!(rules.len(), 2);
        let (write_rule, read_rule) = (&rules[0], &rules[1]);
        assert_eq!(write_rule.method, ["ca", "x509"]);
        assert_eq!(write_rule.authority, ["ANL CA"]);
        assert!(
            read_rule.method.is_empty(),
            "READ rule must not inherit METHOD list",
        );
        assert!(read_rule.authority.is_empty());
    }

    #[test]
    fn tls_x509_acf_rule_grants_write_on_issuer_match() {
        let cfg = parsed(
            r#"
ASG(TLS_ONLY) {
RULE(1, WRITE) { METHOD("x509") AUTHORITY("CN=ops-ca, O=Lab") }
RULE(1, READ)
}
"#,
        );
        let cases = [
            ("", "", AccessLevel::Read),
            ("x509", "CN=other-ca", AccessLevel::Read),
            ("x509", "CN=ops-ca, O=Lab", AccessLevel::ReadWrite),
        ];
        for (method, authority, expected) in cases {
            assert_eq!(
                cfg.check_access_method("TLS_ONLY", "h", "u", 0, method, authority),
                expected
            );
        }
    }

    #[test]
    fn check_access_method_gates_on_method() {
        let cfg = parsed(
            r#"
ASG(METHOD_GATED) {
RULE(1, WRITE) {
METHOD("x509")
}
RULE(1, READ)
}
"#,
        );
        let via = |method: &str| cfg.check_access_method("METHOD_GATED", "h", "u", 0, method, "");
        assert_eq!(via("x509"), AccessLevel::ReadWrite);
        assert_eq!(via("ca"), AccessLevel::Read);
    }

    #[test]
    fn check_access_method_gates_on_authority() {
        let cfg = parsed(
            r#"
ASG(AUTH_GATED) {
RULE(1, WRITE) {
AUTHORITY("Trusted Root")
}
RULE(1, READ)
}
"#,
        );
        let signed_by =
            |authority: &str| cfg.check_access_method("AUTH_GATED", "h", "u", 0, "x509", authority);
        assert_eq!(signed_by("Trusted Root"), AccessLevel::ReadWrite);
        assert_eq!(signed_by("Other CA"), AccessLevel::Read);
    }

    #[test]
    fn check_access_asl_legacy_path_matches_when_method_empty() {
        let cfg = parsed(
            r#"
ASG(LEGACY) {
RULE(1, WRITE)
RULE(1, READ)
}
"#,
        );
        assert_eq!(
            cfg.check_access_asl("LEGACY", "h", "u", 0),
            AccessLevel::ReadWrite
        );
    }

    #[test]
    fn check_access_method_match_is_case_insensitive() {
        let cfg = parsed(
            r#"
ASG(MIXED_CASE) {
RULE(1, WRITE) {
METHOD("X509")
}
}
"#,
        );
        assert_eq!(
            cfg.check_access_method("MIXED_CASE", "h", "u", 0, "x509", ""),
            AccessLevel::ReadWrite
        );
    }
}