use std::sync::Arc;
use epics_pva_rs::pvdata::{FieldDesc, PvField};
use epics_pva_rs::server_native::ChannelContext;
use epics_pva_rs::server_native::source::{ChannelSource, RawMonitorEvent};
use tokio::sync::mpsc;
/// A middleware constructor that wraps one [`ChannelSource`] in another.
///
/// The wrapped value is itself a `ChannelSource`, so layers can be stacked by
/// feeding the output of one `layer` call into the next.
pub trait Layer<S>: Send + Sync + 'static
where
    S: ChannelSource,
{
    /// The source produced by wrapping an `S`.
    type Wrapped: ChannelSource;

    /// Consumes this layer and wraps `inner`.
    fn layer(self, inner: S) -> Self::Wrapped;
}
/// Layer that turns any [`ChannelSource`] into a read-only [`ReadOnly`] source.
pub struct ReadOnlyLayer;

/// A [`ChannelSource`] that rejects PUTs and otherwise defers to `inner`.
pub struct ReadOnly<S> {
    inner: Arc<S>,
}

impl<S: ChannelSource> Layer<S> for ReadOnlyLayer {
    type Wrapped = ReadOnly<S>;

    fn layer(self, inner: S) -> ReadOnly<S> {
        let shared = Arc::new(inner);
        ReadOnly { inner: shared }
    }
}
/// Read-only view of the wrapped source.
///
/// Both PUT entry points fail with `"read-only mode: PUT rejected"` and
/// `is_writable` answers `false`, all without consulting `inner`. Every other
/// operation is delegated to `inner` unchanged.
impl<S: ChannelSource> ChannelSource for ReadOnly<S> {
    async fn list_pvs(&self) -> Vec<String> {
        S::list_pvs(&self.inner).await
    }

    async fn has_pv(&self, name: &str) -> bool {
        S::has_pv(&self.inner, name).await
    }

    async fn get_introspection(&self, name: &str) -> Option<FieldDesc> {
        S::get_introspection(&self.inner, name).await
    }

    async fn get_value(&self, name: &str) -> Option<PvField> {
        S::get_value(&self.inner, name).await
    }

    async fn put_value(&self, _name: &str, _value: PvField) -> Result<(), String> {
        Err(String::from("read-only mode: PUT rejected"))
    }

    async fn put_value_ctx(
        &self,
        _name: &str,
        _value: PvField,
        _ctx: ChannelContext,
    ) -> Result<(), String> {
        Err(String::from("read-only mode: PUT rejected"))
    }

    async fn is_writable(&self, _name: &str) -> bool {
        false
    }

    async fn subscribe(&self, name: &str) -> Option<mpsc::Receiver<PvField>> {
        S::subscribe(&self.inner, name).await
    }

    async fn subscribe_raw(&self, name: &str) -> Option<mpsc::Receiver<RawMonitorEvent>> {
        S::subscribe_raw(&self.inner, name).await
    }

    // NOTE(review): RPC is forwarded unchanged. If an RPC handler can mutate
    // state, read-only mode does not block it — confirm this is intended.
    async fn rpc(
        &self,
        name: &str,
        request_desc: FieldDesc,
        request_value: PvField,
    ) -> Result<(FieldDesc, PvField), String> {
        S::rpc(&self.inner, name, request_desc, request_value).await
    }

    fn notify_watermark_high(&self, name: &str) {
        S::notify_watermark_high(&self.inner, name);
    }

    fn notify_watermark_low(&self, name: &str) {
        S::notify_watermark_low(&self.inner, name);
    }
}
/// Name-based access-control rules used by the ACL layer.
///
/// Every entry is a glob pattern (see [`matches_pattern`]). A PV name is
/// allowed when it matches no `deny` pattern and, if `allow_only` is
/// non-empty, matches at least one `allow_only` pattern. `deny` always wins.
#[derive(Clone, Default, Debug)]
pub struct AclConfig {
    /// Patterns whose matching names are always rejected.
    pub deny: Vec<String>,
    /// When non-empty, only names matching one of these patterns are allowed.
    pub allow_only: Vec<String>,
}

impl AclConfig {
    /// Returns `true` if `name` passes this ACL.
    pub fn allowed(&self, name: &str) -> bool {
        let hit = |pattern: &String| matches_pattern(pattern, name);
        !self.deny.iter().any(hit)
            && (self.allow_only.is_empty() || self.allow_only.iter().any(hit))
    }
}

/// Returns `true` if `name` matches the glob-style `pattern`.
///
/// `*` matches any (possibly empty) run of characters. It may appear any
/// number of times anywhere in the pattern, e.g. `MOTOR:*`, `*:VAL`,
/// `*:JOG:*` or `MOTOR:*:VAL`. Every other character matches itself. A
/// pattern without `*` must equal `name` exactly. There is no escape for a
/// literal `*`.
fn matches_pattern(pattern: &str, name: &str) -> bool {
    // Literal text before the first `*`; without any `*` this is an exact match.
    let Some((head, tail)) = pattern.split_once('*') else {
        return name == pattern;
    };
    // Literal text after the last `*`, and whatever lies between first and last `*`.
    let (middle, last) = tail.rsplit_once('*').unwrap_or(("", tail));
    // Anchor both ends. Stripping the suffix from what is left after the prefix
    // keeps the anchors from overlapping (so `A*A` does not match `A`).
    let Some(rest) = name.strip_prefix(head) else {
        return false;
    };
    let Some(mut rest) = rest.strip_suffix(last) else {
        return false;
    };
    // Inner segments must appear in order in the unanchored middle. The leftmost
    // occurrence is always the best choice: it leaves the most room for the rest.
    for segment in middle.split('*') {
        match rest.find(segment) {
            Some(at) => rest = &rest[at + segment.len()..],
            None => return false,
        }
    }
    true
}
/// Layer that filters a [`ChannelSource`] through an [`AclConfig`].
pub struct AclLayer {
    config: AclConfig,
}

impl AclLayer {
    /// Creates a layer that enforces `config` on the wrapped source.
    pub fn new(config: AclConfig) -> Self {
        AclLayer { config }
    }
}

/// A [`ChannelSource`] that hides and rejects PVs not permitted by `config`.
pub struct Acl<S> {
    inner: Arc<S>,
    config: AclConfig,
}

impl<S: ChannelSource> Layer<S> for AclLayer {
    type Wrapped = Acl<S>;

    fn layer(self, inner: S) -> Acl<S> {
        let AclLayer { config } = self;
        Acl {
            config,
            inner: Arc::new(inner),
        }
    }
}
/// ACL enforcement for every name-addressed operation.
///
/// A denied PV is filtered out of `list_pvs`. It reports `false` from
/// `has_pv` and `is_writable` and yields `None` from lookups and
/// subscriptions. PUT and RPC on it fail with `ACL: PV '<name>' denied`.
/// The wrapped source is not consulted for a denied name. Watermark
/// notifications are forwarded without an ACL check.
impl<S: ChannelSource> ChannelSource for Acl<S> {
    async fn list_pvs(&self) -> Vec<String> {
        let all = self.inner.list_pvs().await;
        all.into_iter()
            .filter(|pv| self.config.allowed(pv))
            .collect()
    }

    async fn has_pv(&self, name: &str) -> bool {
        self.config.allowed(name) && self.inner.has_pv(name).await
    }

    async fn get_introspection(&self, name: &str) -> Option<FieldDesc> {
        if self.config.allowed(name) {
            self.inner.get_introspection(name).await
        } else {
            None
        }
    }

    async fn get_value(&self, name: &str) -> Option<PvField> {
        if self.config.allowed(name) {
            self.inner.get_value(name).await
        } else {
            None
        }
    }

    async fn put_value(&self, name: &str, value: PvField) -> Result<(), String> {
        if self.config.allowed(name) {
            self.inner.put_value(name, value).await
        } else {
            Err(format!("ACL: PV '{name}' denied"))
        }
    }

    async fn put_value_ctx(
        &self,
        name: &str,
        value: PvField,
        ctx: ChannelContext,
    ) -> Result<(), String> {
        if self.config.allowed(name) {
            self.inner.put_value_ctx(name, value, ctx).await
        } else {
            Err(format!("ACL: PV '{name}' denied"))
        }
    }

    async fn is_writable(&self, name: &str) -> bool {
        if self.config.allowed(name) {
            self.inner.is_writable(name).await
        } else {
            false
        }
    }

    async fn subscribe(&self, name: &str) -> Option<mpsc::Receiver<PvField>> {
        if self.config.allowed(name) {
            self.inner.subscribe(name).await
        } else {
            None
        }
    }

    async fn subscribe_raw(&self, name: &str) -> Option<mpsc::Receiver<RawMonitorEvent>> {
        if self.config.allowed(name) {
            self.inner.subscribe_raw(name).await
        } else {
            None
        }
    }

    async fn rpc(
        &self,
        name: &str,
        request_desc: FieldDesc,
        request_value: PvField,
    ) -> Result<(FieldDesc, PvField), String> {
        if self.config.allowed(name) {
            self.inner.rpc(name, request_desc, request_value).await
        } else {
            Err(format!("ACL: PV '{name}' denied"))
        }
    }

    fn notify_watermark_high(&self, name: &str) {
        self.inner.notify_watermark_high(name);
    }

    fn notify_watermark_low(&self, name: &str) {
        self.inner.notify_watermark_low(name);
    }
}
/// Destination for [`AuditEvent`]s.
///
/// The audit layer calls `record` synchronously after each audited operation.
/// A slow sink therefore delays that operation unless it is wrapped in
/// something that queues events.
pub trait AuditSink: Send + Sync + 'static {
    /// Consumes one audit event.
    fn record(&self, event: AuditEvent);
}

/// An [`AuditSink`] that discards every event.
pub struct NoopAudit;

impl AuditSink for NoopAudit {
    fn record(&self, _ev: AuditEvent) {}
}

/// An [`AuditSink`] that hands each event to the wrapped closure.
pub struct ClosureAudit<F>(pub F)
where
    F: Fn(AuditEvent) + Send + Sync + 'static;

impl<F> AuditSink for ClosureAudit<F>
where
    F: Fn(AuditEvent) + Send + Sync + 'static,
{
    fn record(&self, event: AuditEvent) {
        let callback = &self.0;
        callback(event);
    }
}
/// An [`AuditSink`] adapter that takes event delivery off the caller's path.
///
/// `record` does a non-blocking `try_send` into a bounded queue. A dedicated
/// background thread drains the queue into the wrapped sink. Events that
/// cannot be queued are discarded and counted (see [`MpscAuditSink::drops`]).
pub struct MpscAuditSink {
    tx: tokio::sync::mpsc::Sender<AuditEvent>,
    drops: Arc<std::sync::atomic::AtomicU64>,
}

impl MpscAuditSink {
    /// Wraps `inner` so that its (possibly blocking) `record` runs on a
    /// dedicated background thread.
    ///
    /// `capacity` bounds the number of queued events. `0` is treated as `1`
    /// because tokio's bounded channel rejects a zero capacity. The drain
    /// thread exits after this sink (the only sender) is dropped and the
    /// already-queued events have been delivered.
    ///
    /// # Panics
    ///
    /// Panics if the operating system refuses to spawn the drain thread.
    pub fn wrap<A: AuditSink>(capacity: usize, inner: A) -> Self {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<AuditEvent>(capacity.max(1));
        // Use a plain OS thread rather than `tokio::spawn`. A blocking
        // `inner.record` (file/socket I/O) must not occupy an async runtime
        // worker, and this also lets `wrap` work outside a tokio runtime.
        std::thread::Builder::new()
            .name("pva-audit-sink".into())
            .spawn(move || {
                while let Some(ev) = rx.blocking_recv() {
                    inner.record(ev);
                }
            })
            .expect("failed to spawn audit sink thread");
        Self {
            tx,
            drops: Arc::new(std::sync::atomic::AtomicU64::new(0)),
        }
    }

    /// Number of events discarded so far because the queue was full or the
    /// drain thread had stopped.
    pub fn drops(&self) -> u64 {
        self.drops.load(std::sync::atomic::Ordering::Relaxed)
    }
}

impl AuditSink for MpscAuditSink {
    fn record(&self, event: AuditEvent) {
        // Never wait on the drain thread. A full or closed queue costs one
        // counted drop instead of stalling the audited operation.
        if self.tx.try_send(event).is_err() {
            self.drops
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }
    }
}
/// A single audited channel operation, as delivered to an `AuditSink`.
#[derive(Debug, Clone)]
pub struct AuditEvent {
    /// PV the operation targeted.
    pub pv: String,
    /// Operation that was performed.
    pub event: AuditEventKind,
    /// Client account name; empty when the operation carried no client context.
    pub user: String,
    /// Client host; empty when the operation carried no client context.
    pub host: String,
    /// Outcome of the operation.
    pub result: AuditResult,
    /// Wall-clock time at which the event was built.
    pub timestamp: std::time::SystemTime,
    /// Error text for `Denied` / `Failed`; empty for `Ok`.
    pub error: String,
}

/// Kind of channel operation an [`AuditEvent`] describes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuditEventKind {
    /// A PUT.
    Put,
    /// A GET.
    Get,
    /// A monitor subscription.
    Subscribe,
    /// An RPC call.
    Rpc,
}

/// Outcome recorded in an [`AuditEvent`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuditResult {
    /// The operation succeeded.
    Ok,
    /// The error text reads like an access-control / read-only rejection.
    Denied,
    /// The operation failed for any other reason.
    Failed,
}

/// Builds an [`AuditEvent`] for PV `name` from an operation outcome.
///
/// The event kind is [`AuditEventKind::Put`]; callers auditing other
/// operations overwrite `event` afterwards.
///
/// Errors are classified heuristically from their text. A message mentioning
/// `deny`/`denied`, `ACL:` or `read-only` (case-insensitive) is
/// [`AuditResult::Denied`]; anything else is [`AuditResult::Failed`].
fn make_audit_event(name: &str, user: &str, host: &str, result: &Result<(), String>) -> AuditEvent {
    // Lower-case substrings that mark an error as an access-control rejection.
    const DENIAL_MARKERS: [&str; 4] = ["deny", "denied", "acl:", "read-only"];

    let (kind, error) = match result {
        Ok(()) => (AuditResult::Ok, String::new()),
        Err(msg) => {
            // Error messages routinely quote the PV name. Remove it first, so a
            // PV called e.g. `DENY:FLAG` cannot turn an ordinary failure into
            // a spurious denial.
            let scrubbed = if name.is_empty() {
                msg.to_lowercase()
            } else {
                msg.replace(name, "").to_lowercase()
            };
            let denied = DENIAL_MARKERS.iter().any(|m| scrubbed.contains(*m));
            let kind = if denied {
                AuditResult::Denied
            } else {
                AuditResult::Failed
            };
            (kind, msg.clone())
        }
    };
    AuditEvent {
        pv: name.to_string(),
        event: AuditEventKind::Put,
        user: user.to_string(),
        host: host.to_string(),
        result: kind,
        timestamp: std::time::SystemTime::now(),
        error,
    }
}
/// Layer that reports channel operations to an [`AuditSink`].
///
/// PUTs are always audited. GET, subscribe and RPC auditing are opt-in via
/// [`AuditLayer::with_get`], [`AuditLayer::with_subscribe`] and
/// [`AuditLayer::with_rpc`].
pub struct AuditLayer<A: AuditSink> {
    sink: Arc<A>,
    audit_get: bool,
    audit_subscribe: bool,
    audit_rpc: bool,
}

impl<A: AuditSink> AuditLayer<A> {
    /// Creates a layer that records PUTs into `sink`. Other operations are
    /// not audited until enabled.
    pub fn new(sink: A) -> Self {
        AuditLayer {
            sink: Arc::new(sink),
            audit_get: false,
            audit_subscribe: false,
            audit_rpc: false,
        }
    }
}

impl AuditLayer<MpscAuditSink> {
    /// Like [`AuditLayer::new`], but routes events to `inner` through an
    /// [`MpscAuditSink`] queue holding up to `capacity` events. The audited
    /// operation never waits on `inner`. Events that do not fit in the queue
    /// are dropped and counted.
    pub fn with_blocking_sink<I: AuditSink>(capacity: usize, inner: I) -> Self {
        Self::new(MpscAuditSink::wrap(capacity, inner))
    }
}

impl<A: AuditSink> AuditLayer<A> {
    /// Also audit GET operations.
    pub fn with_get(self) -> Self {
        Self {
            audit_get: true,
            ..self
        }
    }

    /// Also audit monitor subscriptions (plain and raw).
    pub fn with_subscribe(self) -> Self {
        Self {
            audit_subscribe: true,
            ..self
        }
    }

    /// Also audit RPC calls.
    pub fn with_rpc(self) -> Self {
        Self {
            audit_rpc: true,
            ..self
        }
    }
}
pub struct Audited<S, A> {
inner: Arc<S>,
sink: Arc<A>,
audit_get: bool,
audit_subscribe: bool,
audit_rpc: bool,
}
impl<S: ChannelSource, A: AuditSink> Layer<S> for AuditLayer<A> {
type Wrapped = Audited<S, A>;
fn layer(self, inner: S) -> Audited<S, A> {
Audited {
inner: Arc::new(inner),
sink: self.sink,
audit_get: self.audit_get,
audit_subscribe: self.audit_subscribe,
audit_rpc: self.audit_rpc,
}
}
}
impl<S: ChannelSource, A: AuditSink> ChannelSource for Audited<S, A> {
async fn list_pvs(&self) -> Vec<String> {
self.inner.list_pvs().await
}
async fn has_pv(&self, name: &str) -> bool {
self.inner.has_pv(name).await
}
async fn get_introspection(&self, name: &str) -> Option<FieldDesc> {
self.inner.get_introspection(name).await
}
async fn get_value(&self, name: &str) -> Option<PvField> {
let result = self.inner.get_value(name).await;
if self.audit_get {
let outcome: Result<(), String> = if result.is_some() {
Ok(())
} else {
Err(format!("PV '{name}' not found"))
};
let mut ev = make_audit_event(name, "", "", &outcome);
ev.event = AuditEventKind::Get;
self.sink.record(ev);
}
result
}
async fn put_value(&self, name: &str, value: PvField) -> Result<(), String> {
let result = self.inner.put_value(name, value).await;
self.sink.record(make_audit_event(name, "", "", &result));
result
}
async fn put_value_ctx(
&self,
name: &str,
value: PvField,
ctx: ChannelContext,
) -> Result<(), String> {
let user = ctx.account.clone();
let host = ctx.host.clone();
let result = self.inner.put_value_ctx(name, value, ctx).await;
self.sink
.record(make_audit_event(name, &user, &host, &result));
result
}
async fn is_writable(&self, name: &str) -> bool {
self.inner.is_writable(name).await
}
async fn subscribe(&self, name: &str) -> Option<mpsc::Receiver<PvField>> {
let result = self.inner.subscribe(name).await;
if self.audit_subscribe {
let outcome: Result<(), String> = if result.is_some() {
Ok(())
} else {
Err(format!("PV '{name}' not subscribable"))
};
let mut ev = make_audit_event(name, "", "", &outcome);
ev.event = AuditEventKind::Subscribe;
self.sink.record(ev);
}
result
}
async fn subscribe_raw(&self, name: &str) -> Option<mpsc::Receiver<RawMonitorEvent>> {
let result = self.inner.subscribe_raw(name).await;
if self.audit_subscribe {
let outcome: Result<(), String> = if result.is_some() {
Ok(())
} else {
Err(format!("PV '{name}' not subscribable (raw)"))
};
let mut ev = make_audit_event(name, "", "", &outcome);
ev.event = AuditEventKind::Subscribe;
self.sink.record(ev);
}
result
}
async fn rpc(
&self,
name: &str,
request_desc: FieldDesc,
request_value: PvField,
) -> Result<(FieldDesc, PvField), String> {
let result = self.inner.rpc(name, request_desc, request_value).await;
if self.audit_rpc {
let outcome: Result<(), String> = match &result {
Ok(_) => Ok(()),
Err(e) => Err(e.clone()),
};
let mut ev = make_audit_event(name, "", "", &outcome);
ev.event = AuditEventKind::Rpc;
self.sink.record(ev);
}
result
}
fn notify_watermark_high(&self, name: &str) {
self.inner.notify_watermark_high(name);
}
fn notify_watermark_low(&self, name: &str) {
self.inner.notify_watermark_low(name);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pattern_matching() {
        let cases = [
            ("MOTOR:*", "MOTOR:VAL", true),
            ("*VAL", "MOTOR:VAL", true),
            ("EXACT", "EXACT", true),
            ("EXACT", "EXACT2", false),
            ("MOTOR:*", "OTHER:VAL", false),
        ];
        for (pattern, name, expected) in cases {
            assert_eq!(
                matches_pattern(pattern, name),
                expected,
                "pattern {pattern:?} against {name:?}"
            );
        }
    }

    #[test]
    fn acl_allow_only() {
        let cfg = AclConfig {
            allow_only: vec!["BL10C:*".to_string()],
            ..AclConfig::default()
        };
        assert!(cfg.allowed("BL10C:VG-01:PRESSURE"));
        assert!(!cfg.allowed("RFP:HV"));
    }

    #[test]
    fn acl_deny_overrides_allow() {
        let cfg = AclConfig {
            deny: vec!["MOTOR:JOG:*".to_string()],
            allow_only: vec!["MOTOR:*".to_string()],
        };
        assert!(cfg.allowed("MOTOR:VAL"));
        assert!(!cfg.allowed("MOTOR:JOG:UP"));
        assert!(!cfg.allowed("OTHER:PV"));
    }

    #[test]
    fn audit_event_classifies_results() {
        let classify = |user: &str, host: &str, outcome: Result<(), String>| {
            make_audit_event("MOTOR:VAL", user, host, &outcome)
        };

        let denied = classify("alice", "host1", Err("ACL: PV 'MOTOR:VAL' denied".to_string()));
        assert_eq!(denied.result, AuditResult::Denied);
        assert!(!denied.error.is_empty());

        let read_only = classify("", "", Err("read-only mode: PUT rejected".to_string()));
        assert_eq!(read_only.result, AuditResult::Denied);

        let failed = classify("alice", "host1", Err("upstream timeout".to_string()));
        assert_eq!(failed.result, AuditResult::Failed);

        let ok = classify("alice", "host1", Ok(()));
        assert_eq!(ok.result, AuditResult::Ok);
        assert!(ok.error.is_empty());
    }
}