use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use arc_swap::ArcSwap;
use epics_base_rs::server::database::{LinkSet, PvDatabase};
use epics_base_rs::server::snapshot::Snapshot;
use epics_base_rs::types::EpicsValue;
use epics_ca_rs::client::{CaChannel, CaClient};
use parking_lot::RwLock;
/// Failures that can occur while creating CA links.
#[derive(Debug, thiserror::Error)]
pub enum CaLinkError {
    /// The Channel Access client could not be constructed.
    #[error("CA client init failed: {0}")]
    ClientInit(String),
    /// Subscribing to a PV's monitor stream failed.
    #[error("CA link subscribe failed for {pv}: {reason}")]
    Subscribe {
        /// PV the subscription was attempted for.
        pv: String,
        /// Underlying client error, rendered as text.
        reason: String,
    },
}
/// One Channel Access link: the channel itself, the most recent monitor
/// snapshot, and a connection flag, the latter two maintained by background tasks.
///
/// Dropping a `CaLink` aborts both of its background tasks.
pub struct CaLink {
    /// Latest snapshot delivered by the monitor task; `None` until the first update.
    cache: Arc<ArcSwap<Option<Snapshot>>>,
    /// Connection state, written by the monitor and connection-watcher tasks.
    connected: Arc<AtomicBool>,
    /// The underlying channel; also used for puts.
    channel: Arc<CaChannel>,
    // NOTE: field order is drop order; the task guards below are dropped last.
    /// Task that feeds `cache`; aborted on drop.
    _monitor_task: AbortOnDrop,
    /// Task that tracks connection events; aborted on drop.
    _conn_task: AbortOnDrop,
}
/// Guard owning a spawned task's handle; the task is aborted when the guard
/// is dropped, tying the task's lifetime to the guard's owner.
struct AbortOnDrop(tokio::task::JoinHandle<()>);

impl Drop for AbortOnDrop {
    fn drop(&mut self) {
        let Self(task) = self;
        task.abort();
    }
}
impl CaLink {
pub fn is_connected(&self) -> bool {
self.connected.load(Ordering::Acquire) && self.cache.load().as_ref().is_some()
}
pub fn value(&self) -> Option<EpicsValue> {
if !self.connected.load(Ordering::Acquire) {
return None;
}
self.cache.load().as_ref().as_ref().map(|s| s.value.clone())
}
pub fn alarm_severity(&self) -> Option<i32> {
if !self.connected.load(Ordering::Acquire) {
return None;
}
self.cache
.load()
.as_ref()
.as_ref()
.map(|s| s.alarm.severity as i32)
}
pub fn time_stamp(&self) -> Option<(i64, i32)> {
if !self.connected.load(Ordering::Acquire) {
return None;
}
let snap = self.cache.load();
let snap = snap.as_ref().as_ref()?;
let dur = snap.timestamp.duration_since(std::time::UNIX_EPOCH).ok()?;
Some((dur.as_secs() as i64, dur.subsec_nanos() as i32))
}
}
/// Resolves database links registered under the `ca` scheme to Channel Access PVs.
///
/// Every field is reference-counted (or a runtime handle), so clones share the
/// same client, runtime, and link table.
#[derive(Clone)]
pub struct CaLinkResolver {
    /// Client used to create channels.
    client: Arc<CaClient>,
    /// Runtime on which link tasks are spawned and synchronous calls block.
    handle: tokio::runtime::Handle,
    /// Open links, keyed by the PV name they were opened under.
    links: Arc<RwLock<HashMap<String, Arc<CaLink>>>>,
}
impl CaLinkResolver {
pub async fn new(handle: tokio::runtime::Handle) -> Result<Self, CaLinkError> {
let client = CaClient::new()
.await
.map_err(|e| CaLinkError::ClientInit(e.to_string()))?;
Ok(Self {
client: Arc::new(client),
handle,
links: Arc::new(RwLock::new(HashMap::new())),
})
}
pub fn with_client(client: Arc<CaClient>, handle: tokio::runtime::Handle) -> Self {
Self {
client,
handle,
links: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn open(&self, pv_name: &str) -> Result<Arc<CaLink>, CaLinkError> {
if let Some(existing) = self.links.read().get(pv_name).cloned() {
return Ok(existing);
}
let channel = Arc::new(self.client.create_channel(pv_name));
let monitor = channel
.subscribe()
.await
.map_err(|e| CaLinkError::Subscribe {
pv: pv_name.to_string(),
reason: e.to_string(),
})?;
let cache: Arc<ArcSwap<Option<Snapshot>>> = Arc::new(ArcSwap::from_pointee(None));
let connected = Arc::new(AtomicBool::new(false));
let conn_rx = channel.connection_events();
let conn_task = self
.handle
.spawn(run_connection_watcher(conn_rx, connected.clone()));
let task = self.handle.spawn(run_monitor(
monitor,
cache.clone(),
connected.clone(),
pv_name.to_string(),
));
let link = Arc::new(CaLink {
cache,
connected,
channel,
_monitor_task: AbortOnDrop(task),
_conn_task: AbortOnDrop(conn_task),
});
let mut links = self.links.write();
if let Some(existing) = links.get(pv_name).cloned() {
return Ok(existing);
}
links.insert(pv_name.to_string(), link.clone());
Ok(link)
}
pub async fn wait_for_link_connected(
&self,
pv_name: &str,
timeout: std::time::Duration,
) -> bool {
let name = strip_ca_scheme(pv_name);
let link = match self.open(name).await {
Ok(l) => l,
Err(_) => return false,
};
let deadline = std::time::Instant::now() + timeout;
loop {
if link.value().is_some() {
return true;
}
if std::time::Instant::now() >= deadline {
return false;
}
tokio::time::sleep(std::time::Duration::from_millis(25)).await;
}
}
pub fn link_count(&self) -> usize {
self.links.read().len()
}
fn link_for(&self, name: &str) -> Option<Arc<CaLink>> {
if let Some(existing) = self.links.read().get(name).cloned() {
return Some(existing);
}
let resolver = self.clone();
let name = name.to_string();
block_in_place_or_warn(move || resolver.handle.block_on(resolver.open(&name)).ok())
}
}
/// Drains `monitor`, publishing each received snapshot into `cache` and
/// marking the link connected.
///
/// Error events are logged at debug level and otherwise ignored, so the last
/// good snapshot stays cached. When the monitor stream ends, `connected` is
/// cleared.
async fn run_monitor(
    mut monitor: epics_ca_rs::client::MonitorHandle,
    cache: Arc<ArcSwap<Option<Snapshot>>>,
    connected: Arc<AtomicBool>,
    pv_name: String,
) {
    while let Some(event) = monitor.recv().await {
        match event {
            Ok(snapshot) => {
                // Publish the data before the flag: a reader that observes
                // `connected == true` (Acquire) is then guaranteed to also see
                // this snapshot, instead of a still-empty cache.
                cache.store(Arc::new(Some(snapshot)));
                connected.store(true, Ordering::Release);
            }
            Err(e) => {
                tracing::debug!(
                    pv = %pv_name,
                    error = %e,
                    "calink: monitor error event ignored, keeping last cached value"
                );
            }
        }
    }
    // The monitor stream has ended; no further updates can arrive.
    connected.store(false, Ordering::Release);
}
async fn run_connection_watcher(
mut conn_rx: epics_base_rs::runtime::sync::broadcast::Receiver<
epics_ca_rs::client::ConnectionEvent,
>,
connected: Arc<AtomicBool>,
) {
use epics_ca_rs::client::ConnectionEvent;
loop {
match conn_rx.recv().await {
Ok(ConnectionEvent::Connected) => connected.store(true, Ordering::Release),
Ok(ConnectionEvent::Disconnected) | Ok(ConnectionEvent::Unresponsive) => {
connected.store(false, Ordering::Release)
}
Ok(_) => {}
Err(epics_base_rs::runtime::sync::broadcast::error::RecvError::Lagged(_)) => continue,
Err(epics_base_rs::runtime::sync::broadcast::error::RecvError::Closed) => return,
}
}
}
impl LinkSet for CaLinkResolver {
fn is_connected(&self, name: &str) -> bool {
let name = strip_ca_scheme(name);
match self.links.read().get(name) {
Some(link) => link.is_connected(),
None => false,
}
}
fn get_value(&self, name: &str) -> Option<EpicsValue> {
let name = strip_ca_scheme(name);
self.link_for(name)?.value()
}
fn put_value(&self, name: &str, value: EpicsValue) -> Result<(), String> {
let name = strip_ca_scheme(name);
let link = self
.link_for(name)
.ok_or_else(|| format!("CA link {name} not open"))?;
let channel = link.channel.clone();
block_in_place_or_warn(move || {
self.handle
.block_on(async { channel.put(&value).await })
.map_err(|e| e.to_string())
})
}
fn alarm_severity(&self, name: &str) -> Option<i32> {
let name = strip_ca_scheme(name);
let sev = self.link_for(name)?.alarm_severity()?;
if sev > 0 { Some(sev) } else { None }
}
fn time_stamp(&self, name: &str) -> Option<(i64, i32)> {
let name = strip_ca_scheme(name);
self.link_for(name)?.time_stamp()
}
fn link_names(&self) -> Vec<String> {
self.links.read().keys().cloned().collect()
}
}
/// Removes a single leading `ca://` from `name`; other names pass through unchanged.
fn strip_ca_scheme(name: &str) -> &str {
    const SCHEME: &str = "ca://";
    match name.strip_prefix(SCHEME) {
        Some(bare) => bare,
        None => name,
    }
}
fn block_in_place_or_warn<F, R>(f: F) -> R
where
F: FnOnce() -> R,
{
use tokio::runtime::{Handle, RuntimeFlavor};
if let Ok(handle) = Handle::try_current() {
match handle.runtime_flavor() {
RuntimeFlavor::MultiThread => tokio::task::block_in_place(f),
_ => f(),
}
} else {
f()
}
}
/// Creates a [`CaLinkResolver`] on `handle` and registers a clone of it with
/// `db` under the `"ca"` link scheme.
///
/// The returned resolver shares its link table with the registered clone.
///
/// # Errors
/// Propagates [`CaLinkError::ClientInit`] from resolver construction.
pub async fn install_calink_resolver(
    db: &PvDatabase,
    handle: tokio::runtime::Handle,
) -> Result<CaLinkResolver, CaLinkError> {
    let resolver = CaLinkResolver::new(handle).await?;
    let registered = Arc::new(resolver.clone());
    db.register_link_set("ca", registered).await;
    Ok(resolver)
}
#[cfg(test)]
mod tests {
    use super::*;
    use epics_ca_rs::client::ConnectionEvent;
    use std::time::{Duration, Instant, SystemTime};

    #[test]
    fn strip_ca_scheme_handles_both_forms() {
        for (input, expected) in [("ca://OTHER:PV", "OTHER:PV"), ("OTHER:PV", "OTHER:PV")] {
            assert_eq!(strip_ca_scheme(input), expected);
        }
    }

    /// Polls `flag` every 5 ms until it equals `want`; panics after 2 s.
    async fn await_flag(flag: &AtomicBool, want: bool) {
        let deadline = Instant::now() + Duration::from_secs(2);
        loop {
            if flag.load(Ordering::Acquire) == want {
                break;
            }
            assert!(
                Instant::now() < deadline,
                "connected flag never reached {want}"
            );
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    }

    #[tokio::test]
    async fn bug1_connection_watcher_tracks_disconnect() {
        let (tx, rx) = epics_base_rs::runtime::sync::broadcast::channel::<ConnectionEvent>(16);
        let connected = Arc::new(AtomicBool::new(false));
        let watcher = tokio::spawn(run_connection_watcher(rx, Arc::clone(&connected)));

        // Each event, followed by the flag value the watcher must settle on.
        let script = [
            (ConnectionEvent::Connected, true),
            (ConnectionEvent::Disconnected, false),
            (ConnectionEvent::Connected, true),
            (ConnectionEvent::Unresponsive, false),
        ];
        for (event, expected) in script {
            tx.send(event).unwrap();
            await_flag(&connected, expected).await;
        }

        // Closing the channel must let the watcher return.
        drop(tx);
        tokio::time::timeout(Duration::from_secs(2), watcher)
            .await
            .expect("watcher must exit when the event channel closes")
            .expect("watcher task panicked");
    }

    #[test]
    fn bug1_disconnected_link_serves_no_stale_value() {
        let cache: Arc<ArcSwap<Option<Snapshot>>> = Arc::new(ArcSwap::from_pointee(None));
        let connected = Arc::new(AtomicBool::new(false));
        let snapshot = Snapshot::new(EpicsValue::Double(42.0), 0, 0, SystemTime::UNIX_EPOCH);
        cache.store(Arc::new(Some(snapshot)));

        // Mirror CaLink::is_connected / CaLink::value without a live channel.
        let link_is_connected =
            || connected.load(Ordering::Acquire) && cache.load().as_ref().is_some();
        let link_value = || {
            if connected.load(Ordering::Acquire) {
                cache.load().as_ref().as_ref().map(|s| s.value.clone())
            } else {
                None
            }
        };

        assert!(
            !link_is_connected(),
            "a disconnected link must report not-connected even with a cached snapshot"
        );
        assert!(
            link_value().is_none(),
            "a disconnected link must serve no stale value"
        );

        connected.store(true, Ordering::Release);
        assert!(
            link_is_connected(),
            "reconnected link with cache must be connected"
        );
    }
}