use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use parking_lot::Mutex;
use tokio::sync::mpsc;
use epics_pva_rs::client::PvaClient;
use epics_pva_rs::pvdata::{FieldDesc, PvField};
use epics_pva_rs::server_native::source::{ChannelContext, ChannelSource};
use super::channel_cache::ChannelCache;
/// A raw monitor update: payload bytes plus the byte order they are encoded in.
///
/// `GatewayChannelSource::subscribe_raw` forwards these to downstream
/// subscribers as `RawMonitorEvent`s without decoding the payload.
#[derive(Clone, Debug)]
pub struct RawEvent {
    /// Payload bytes of the update, forwarded as-is.
    pub body: bytes::Bytes,
    /// Byte order in which `body` is encoded.
    pub byte_order: epics_pva_rs::proto::ByteOrder,
}
/// [`ChannelSource`] that serves downstream PVA requests from an upstream
/// [`ChannelCache`], forwarding PUT and RPC operations to the upstream side.
///
/// Cloning is cheap: the cache, the live-subscriber counter and the
/// per-credential client pool are all shared through `Arc`s.
#[derive(Clone)]
pub struct GatewayChannelSource {
    /// Timeout passed to every `ChannelCache::lookup` call.
    pub connect_timeout: Duration,
    /// Capacity of each downstream subscriber's mpsc queue.
    pub subscriber_queue: usize,
    /// Upper bound on a single forwarded RPC.
    pub rpc_timeout: Duration,
    /// Maximum number of concurrently live subscriptions.
    pub max_subscribers: usize,
    /// Upstream channel cache shared by all clones.
    cache: Arc<ChannelCache>,
    /// Number of live subscriptions, shared by all clones.
    subscriber_count: Arc<AtomicUsize>,
    /// Upstream clients for authenticated callers, keyed by `(account, method)`.
    upstream_pool: Arc<Mutex<HashMap<(String, String), Arc<PvaClient>>>>,
}
impl GatewayChannelSource {
    /// Creates a source backed by `cache` with default tuning:
    /// 5 s connect timeout, 64-slot subscriber queues, 30 s RPC timeout and a
    /// cap of 100 000 concurrent subscribers.
    pub fn new(cache: Arc<ChannelCache>) -> Self {
        Self {
            cache,
            connect_timeout: Duration::from_secs(5),
            subscriber_queue: 64,
            rpc_timeout: Duration::from_secs(30),
            max_subscribers: 100_000,
            subscriber_count: Arc::new(AtomicUsize::new(0)),
            upstream_pool: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns the upstream client to use for a request carrying `ctx`'s credentials.
    ///
    /// Anonymous requests (empty account, or method `"anonymous"`) share the
    /// cache's client. Every other `(account, method)` pair gets its own client,
    /// built lazily on first use and reused afterwards.
    ///
    /// NOTE(review): `ctx.host` is not part of the pool key, so a pooled client
    /// keeps the host of whichever request created it. The pool is also never
    /// pruned, so it grows with the number of distinct credentials seen.
    fn upstream_client_for(&self, ctx: &ChannelContext) -> Arc<PvaClient> {
        if ctx.account.is_empty() || ctx.method == "anonymous" {
            return self.cache.client().clone();
        }
        let key = (ctx.account.clone(), ctx.method.clone());
        // One hash lookup. The client is built while the lock is held, so two
        // concurrent first requests cannot both create a client for the same key.
        self.upstream_pool
            .lock()
            .entry(key)
            .or_insert_with(|| {
                Arc::new(
                    PvaClient::builder()
                        .user(ctx.account.clone())
                        .host(ctx.host.clone())
                        .build(),
                )
            })
            .clone()
    }

    /// The upstream channel cache backing this source.
    pub fn cache(&self) -> &Arc<ChannelCache> {
        &self.cache
    }

    /// Number of entries currently held by the channel cache.
    pub async fn cached_entry_count(&self) -> usize {
        self.cache.entry_count().await
    }

    /// Number of subscriptions currently holding a slot in the subscriber cap.
    pub fn live_subscribers(&self) -> usize {
        self.subscriber_count.load(Ordering::Relaxed)
    }
}
impl ChannelSource for GatewayChannelSource {
    /// Lists the names of every PV currently held in the channel cache.
    async fn list_pvs(&self) -> Vec<String> {
        self.cache.names().await
    }

    /// Returns `true` when the cache resolves `name` within `connect_timeout`.
    async fn has_pv(&self, name: &str) -> bool {
        self.cache.lookup(name, self.connect_timeout).await.is_ok()
    }

    /// Returns the type description held by `name`'s cache entry.
    ///
    /// Returns `None` if the lookup fails or the entry has no introspection.
    async fn get_introspection(&self, name: &str) -> Option<FieldDesc> {
        let entry = self.cache.lookup(name, self.connect_timeout).await.ok()?;
        entry.introspection()
    }

    /// Returns the value snapshot held by `name`'s cache entry.
    ///
    /// Returns `None` if the lookup fails or the entry has no snapshot.
    async fn get_value(&self, name: &str) -> Option<PvField> {
        let entry = self.cache.lookup(name, self.connect_timeout).await.ok()?;
        entry.snapshot()
    }

    /// Forwards a PUT upstream through the cache's shared client.
    ///
    /// The PV is resolved first, so an unknown name fails before anything is
    /// sent. The value is rendered with `pvfield_to_pvput_string`.
    ///
    /// # Errors
    /// Returns a stringified error if the lookup fails, if the value shape
    /// cannot be rendered, or if the upstream PUT fails.
    async fn put_value(&self, name: &str, value: PvField) -> Result<(), String> {
        let _entry = self
            .cache
            .lookup(name, self.connect_timeout)
            .await
            .map_err(|e| e.to_string())?;
        let value_str = pvfield_to_pvput_string(&value)
            .ok_or_else(|| "unsupported PvField shape for upstream PUT".to_string())?;
        self.cache
            .client()
            .pvput(name, &value_str)
            .await
            .map_err(|e| e.to_string())
    }

    /// Same as `put_value`, but the PUT goes through the upstream client chosen
    /// for the downstream caller's credentials in `ctx`.
    ///
    /// # Errors
    /// Same as `put_value`.
    async fn put_value_ctx(
        &self,
        name: &str,
        value: PvField,
        ctx: ChannelContext,
    ) -> Result<(), String> {
        let _entry = self
            .cache
            .lookup(name, self.connect_timeout)
            .await
            .map_err(|e| e.to_string())?;
        let value_str = pvfield_to_pvput_string(&value)
            .ok_or_else(|| "unsupported PvField shape for upstream PUT".to_string())?;
        let client = self.upstream_client_for(&ctx);
        tracing::debug!(
            pv = %name,
            account = %ctx.account,
            method = %ctx.method,
            "pva-gateway: forwarding PUT with downstream credentials"
        );
        client
            .pvput(name, &value_str)
            .await
            .map_err(|e| e.to_string())
    }

    /// Reports a PV as writable when it is already present in the cache.
    ///
    /// Uses `peek` rather than `lookup`, so no connect timeout applies.
    async fn is_writable(&self, name: &str) -> bool {
        self.cache.peek(name).await.is_some()
    }

    /// Forwards an RPC upstream through the cache's shared client.
    ///
    /// # Errors
    /// Returns a stringified lookup or upstream error, or a timeout message
    /// when the RPC exceeds `rpc_timeout`.
    async fn rpc(
        &self,
        name: &str,
        request_desc: FieldDesc,
        request_value: PvField,
    ) -> Result<(FieldDesc, PvField), String> {
        let _entry = self
            .cache
            .lookup(name, self.connect_timeout)
            .await
            .map_err(|e| e.to_string())?;
        let result = tokio::time::timeout(
            self.rpc_timeout,
            self.cache
                .client()
                .pvrpc(name, &request_desc, &request_value),
        )
        .await;
        match result {
            Ok(Ok(pair)) => Ok(pair),
            Ok(Err(e)) => Err(e.to_string()),
            Err(_) => Err(format!("upstream rpc timeout for {name}")),
        }
    }

    /// Subscribes to raw monitor frames for `name` and forwards them without decoding.
    ///
    /// Returns `None` in three cases:
    /// - `EPICS_PVA_GW_RAW_FRAMES` is set to `NO`/`FALSE` (case-insensitive) or `0`;
    /// - the subscriber cap is reached;
    /// - the lookup fails.
    ///
    /// A raw subscription takes one slot of `max_subscribers`, exactly like
    /// `subscribe`. The slot is released when the forwarding task ends.
    async fn subscribe_raw(
        &self,
        name: &str,
    ) -> Option<mpsc::Receiver<epics_pva_rs::server_native::RawMonitorEvent>> {
        /// Releases one subscriber slot when dropped.
        struct SubscriberSlot(Arc<AtomicUsize>);
        impl Drop for SubscriberSlot {
            fn drop(&mut self) {
                self.0.fetch_sub(1, Ordering::Relaxed);
            }
        }

        if let Some(v) = epics_base_rs::runtime::env::get("EPICS_PVA_GW_RAW_FRAMES") {
            if v.eq_ignore_ascii_case("NO") || v.eq_ignore_ascii_case("FALSE") || v == "0" {
                return None;
            }
        }
        // Reserve a slot before anything can fail. The previous version never
        // incremented the counter here, but its task guard still decremented it
        // on exit, which underflowed the counter and made `subscribe` refuse
        // everyone. The guard now releases the slot on every path.
        let prev = self.subscriber_count.fetch_add(1, Ordering::Relaxed);
        let slot = SubscriberSlot(self.subscriber_count.clone());
        if prev >= self.max_subscribers {
            tracing::warn!(
                pv = %name,
                live = prev,
                cap = self.max_subscribers,
                "pva-gateway: subscriber cap reached, refusing"
            );
            return None;
        }
        let entry = self.cache.lookup(name, self.connect_timeout).await.ok()?;
        let mut bcast = entry.subscribe_raw();
        // tokio's bounded channel panics on a zero capacity, and `subscriber_queue` is public.
        let (mpsc_tx, mpsc_rx) = mpsc::channel::<epics_pva_rs::server_native::RawMonitorEvent>(
            self.subscriber_queue.max(1),
        );
        tokio::spawn(async move {
            // Owned by the task: dropped when the task finishes or is cancelled.
            let _slot = slot;
            loop {
                match bcast.recv().await {
                    Ok(ev) => {
                        let out = epics_pva_rs::server_native::RawMonitorEvent {
                            body_bytes: ev.body,
                            byte_order: ev.byte_order,
                        };
                        if mpsc_tx.send(out).await.is_err() {
                            return;
                        }
                    }
                    // A slow subscriber skips the frames it missed instead of aborting.
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
                        continue;
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Closed) => return,
                }
            }
        });
        Some(mpsc_rx)
    }

    /// Subscribes to decoded monitor updates for `name`.
    ///
    /// The entry's current snapshot, if any, is sent first, followed by every
    /// broadcast update. Returns `None` when the subscriber cap is reached or
    /// the lookup fails. The cap slot is held until the forwarding task ends.
    async fn subscribe(&self, name: &str) -> Option<mpsc::Receiver<PvField>> {
        /// Releases one subscriber slot when dropped.
        struct SubscriberSlot(Arc<AtomicUsize>);
        impl Drop for SubscriberSlot {
            fn drop(&mut self) {
                self.0.fetch_sub(1, Ordering::Relaxed);
            }
        }

        let prev = self.subscriber_count.fetch_add(1, Ordering::Relaxed);
        // Create the guard before any early return, so every exit path
        // (including a task dropped before it is first polled) releases the slot.
        let slot = SubscriberSlot(self.subscriber_count.clone());
        if prev >= self.max_subscribers {
            tracing::warn!(
                pv = %name,
                live = prev,
                cap = self.max_subscribers,
                "pva-gateway: subscriber cap reached, refusing"
            );
            return None;
        }
        let entry = self.cache.lookup(name, self.connect_timeout).await.ok()?;
        // The receiver is created before the snapshot is taken. Keep this order
        // so an update published in between is not missed.
        // NOTE(review): assumes snapshot() reflects the latest broadcast value.
        let mut bcast_rx = entry.subscribe();
        let initial = entry.snapshot();
        // tokio's bounded channel panics on a zero capacity, and `subscriber_queue` is public.
        let (mpsc_tx, mpsc_rx) = mpsc::channel(self.subscriber_queue.max(1));
        tokio::spawn(async move {
            // Owned by the task: dropped when the task finishes or is cancelled.
            let _slot = slot;
            if let Some(v) = initial {
                if mpsc_tx.send(v).await.is_err() {
                    return;
                }
            }
            loop {
                match bcast_rx.recv().await {
                    Ok(v) => {
                        if mpsc_tx.send(v).await.is_err() {
                            return;
                        }
                    }
                    // A slow subscriber skips the updates it missed instead of aborting.
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
                        continue;
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Closed) => return,
                }
            }
        });
        Some(mpsc_rx)
    }

    /// Called when a downstream monitor outbox crosses its high watermark.
    ///
    /// Spawns a task that, if `name` is cached and its entry exposes a pauser,
    /// calls `pause()` on that pauser.
    fn notify_watermark_high(&self, name: &str) {
        tracing::warn!(
            pv = %name,
            "pva-gateway: downstream monitor outbox crossed high watermark"
        );
        let cache = self.cache.clone();
        let name_owned = name.to_string();
        tokio::spawn(async move {
            if let Some(entry) = cache.peek(&name_owned).await {
                if let Some(p) = entry.pauser_snapshot() {
                    p.pause().await;
                }
            }
        });
    }

    /// Called when a downstream monitor outbox drains below its low watermark.
    ///
    /// Spawns a task that, if `name` is cached and its entry exposes a pauser,
    /// calls `resume()` on that pauser.
    fn notify_watermark_low(&self, name: &str) {
        tracing::debug!(
            pv = %name,
            "pva-gateway: downstream monitor outbox drained below low watermark"
        );
        let cache = self.cache.clone();
        let name_owned = name.to_string();
        tokio::spawn(async move {
            if let Some(entry) = cache.peek(&name_owned).await {
                if let Some(p) = entry.pauser_snapshot() {
                    p.resume().await;
                }
            }
        });
    }
}
/// Renders a PV value as the text argument passed to the upstream `pvput` call.
///
/// - scalars: the text produced by `scalar_to_string`;
/// - scalar arrays: the elements' text joined by single spaces;
/// - structures: the rendering of the first field named `value`;
/// - variants: the rendering of the wrapped value;
/// - unions: the rendering of the selected member (`None` for a negative selector).
///
/// Any other shape yields `None`, which callers report as an unsupported PUT.
fn pvfield_to_pvput_string(v: &PvField) -> Option<String> {
    match v {
        PvField::Scalar(sv) => Some(scalar_to_string(sv)),
        PvField::ScalarArray(items) => Some(
            items
                .iter()
                .map(scalar_to_string)
                .collect::<Vec<String>>()
                .join(" "),
        ),
        PvField::Structure(s) => (&s.fields)
            .into_iter()
            .find_map(|(name, field)| if name == "value" { Some(field) } else { None })
            .and_then(|field| pvfield_to_pvput_string(field)),
        PvField::Variant(boxed) => pvfield_to_pvput_string(&boxed.value),
        PvField::Union {
            selector, value, ..
        } if *selector >= 0 => pvfield_to_pvput_string(value),
        _ => None,
    }
}
fn scalar_to_string(sv: &epics_pva_rs::pvdata::ScalarValue) -> String {
use epics_pva_rs::pvdata::ScalarValue::*;
match sv {
Boolean(b) => {
if *b {
"1".into()
} else {
"0".into()
}
}
Byte(x) => x.to_string(),
UByte(x) => x.to_string(),
Short(x) => x.to_string(),
UShort(x) => x.to_string(),
Int(x) => x.to_string(),
UInt(x) => x.to_string(),
Long(x) => x.to_string(),
ULong(x) => x.to_string(),
Float(x) => x.to_string(),
Double(x) => x.to_string(),
String(s) => s.clone(),
}
}