use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use crate::DbFieldType;
use crate::client::{CaChannel, CaClient};
use crate::protocol::{DBE_ALARM, DBE_VALUE};
use arc_swap::ArcSwap;
use epics_base_rs::server::database::{LinkDbfType, LinkMetadata, LinkPutOp, LinkSet, PvDatabase};
use epics_base_rs::server::snapshot::{DbrClass, Snapshot};
use epics_base_rs::types::EpicsValue;
use parking_lot::RwLock;
/// Event mask used for every link monitor subscription: the `DBE_VALUE` and
/// `DBE_ALARM` bits combined. It is handed to `subscribe_with_mask` when a
/// link is opened.
const CALINK_EVENT_MASK: u16 = DBE_VALUE | DBE_ALARM;
/// Latest monitor update kept for a link, together with the element count it
/// was cached under.
struct CachedSnapshot {
/// Snapshot delivered by the monitor.
snapshot: Snapshot,
/// Element count recorded when the snapshot was cached. It is compared with
/// the channel's current element count before the snapshot is served.
native_count: u32,
}
/// Reports whether cached data still describes the channel as it is now.
///
/// Returns `true` only when both the field type and the element count of the
/// cached data equal the current ones; a difference in either yields `false`.
fn cache_native_matches(
    cached_type: DbFieldType,
    cached_count: u32,
    current_type: DbFieldType,
    current_count: u32,
) -> bool {
    if cached_type != current_type {
        return false;
    }
    cached_count == current_count
}
/// Errors reported while opening a CA link.
#[derive(Debug, thiserror::Error)]
pub enum CaLinkError {
/// The CA client could not be created; carries the underlying error text.
#[error("CA client init failed: {0}")]
ClientInit(String),
/// Subscribing a monitor on `pv` failed; `reason` carries the underlying
/// error text.
#[error("CA link subscribe failed for {pv}: {reason}")]
Subscribe { pv: String, reason: String },
}
/// One open CA link: a channel plus the state its background tasks maintain.
pub struct CaLink {
/// Latest snapshot stored by the monitor task; `None` until the first update.
cache: Arc<ArcSwap<Option<CachedSnapshot>>>,
/// Connection flag written by the connection watcher and the monitor task.
connected: Arc<AtomicBool>,
/// Link metadata stored by the metadata fetch; `None` until a fetch completes.
meta: Arc<ArcSwap<Option<LinkMetadata>>>,
/// Channel used for puts and for reading the current native type / count.
channel: Arc<CaChannel>,
/// Monitor task handle; the task is aborted when the link is dropped.
_monitor_task: AbortOnDrop,
/// Connection-watcher task handle; the task is aborted when the link is dropped.
_conn_task: AbortOnDrop,
}
/// Owns a spawned task handle and aborts the task when the wrapper is dropped.
struct AbortOnDrop(tokio::task::JoinHandle<()>);

impl Drop for AbortOnDrop {
    /// Requests abortion of the wrapped task.
    fn drop(&mut self) {
        let Self(task) = self;
        task.abort();
    }
}
impl CaLink {
    /// Applies `f` to the cached snapshot, but only while it may be served.
    ///
    /// Yields `None` when the connected flag is clear, when nothing has been
    /// cached yet, or when the cached snapshot no longer matches the channel's
    /// current native type and element count.
    fn with_servable<R>(&self, f: impl FnOnce(&Snapshot) -> R) -> Option<R> {
        if !self.connected.load(Ordering::Acquire) {
            return None;
        }
        let guard = self.cache.load();
        match guard.as_ref().as_ref() {
            Some(entry) if self.cache_matches_channel(entry) => Some(f(&entry.snapshot)),
            _ => None,
        }
    }

    /// Checks the cached snapshot against the channel's current native field
    /// type and element count. If either cannot be read from the channel the
    /// cache is treated as not matching.
    fn cache_matches_channel(&self, cached: &CachedSnapshot) -> bool {
        let live_type = self.channel.native_field_type();
        let live_count = self.channel.element_count();
        if let (Ok(live_type), Ok(live_count)) = (live_type, live_count) {
            cache_native_matches(
                cached.snapshot.value.dbr_type(),
                cached.native_count,
                live_type,
                live_count,
            )
        } else {
            false
        }
    }

    /// `true` when the link currently has a servable snapshot.
    pub fn is_connected(&self) -> bool {
        self.with_servable(|_| ()).is_some()
    }

    /// Clone of the cached value, if servable.
    pub fn value(&self) -> Option<EpicsValue> {
        self.with_servable(|snap| snap.value.clone())
    }

    /// Alarm severity of the cached snapshot as an `i32`, if servable.
    pub fn alarm_severity(&self) -> Option<i32> {
        self.with_servable(|snap| snap.alarm.severity as i32)
    }

    /// Alarm status of the cached snapshot as an `i32`, if servable.
    pub fn alarm_status(&self) -> Option<i32> {
        self.with_servable(|snap| snap.alarm.status as i32)
    }

    /// Timestamp of the cached snapshot, if servable, as
    /// `(seconds since the Unix epoch, sub-second nanoseconds, 0)`.
    /// The third element is always `0`.
    pub fn time_stamp(&self) -> Option<(i64, i32, u64)> {
        self.with_servable(|snap| {
            let since_epoch = snap.timestamp.since_unix_epoch();
            (
                since_epoch.as_secs() as i64,
                since_epoch.subsec_nanos() as i32,
                0,
            )
        })
    }

    /// Clone of the stored link metadata while the connected flag is set.
    ///
    /// Unlike the value accessors, this consults only the connected flag and
    /// not the snapshot cache.
    pub fn link_metadata(&self) -> Option<LinkMetadata> {
        if !self.connected.load(Ordering::Acquire) {
            return None;
        }
        let stored = self.meta.load();
        stored.as_ref().clone()
    }
}
/// Opens and tracks CA links by PV name. Clones share the same client, link
/// table and database slot, because every field is an `Arc` or a runtime handle.
#[derive(Clone)]
pub struct CaLinkResolver {
/// CA client, created on first use unless one was supplied up front.
client: Arc<tokio::sync::OnceCell<Arc<CaClient>>>,
/// Runtime handle used to spawn link tasks and to block on async calls.
handle: tokio::runtime::Handle,
/// Open links keyed by PV name.
links: Arc<RwLock<HashMap<String, Arc<CaLink>>>>,
/// Database notified after each monitor update; `None` until attached.
db: Arc<RwLock<Option<PvDatabase>>>,
}
impl CaLinkResolver {
pub fn new(handle: tokio::runtime::Handle) -> Self {
Self {
client: Arc::new(tokio::sync::OnceCell::new()),
handle,
links: Arc::new(RwLock::new(HashMap::new())),
db: Arc::new(RwLock::new(None)),
}
}
pub fn with_client(client: Arc<CaClient>, handle: tokio::runtime::Handle) -> Self {
Self {
client: Arc::new(tokio::sync::OnceCell::new_with(Some(client))),
handle,
links: Arc::new(RwLock::new(HashMap::new())),
db: Arc::new(RwLock::new(None)),
}
}
async fn client(&self) -> Result<&Arc<CaClient>, CaLinkError> {
self.client
.get_or_try_init(|| async {
CaClient::new()
.await
.map(Arc::new)
.map_err(|e| CaLinkError::ClientInit(e.to_string()))
})
.await
}
pub fn attach_database(&self, db: PvDatabase) {
*self.db.write() = Some(db);
}
pub async fn open(&self, pv_name: &str) -> Result<Arc<CaLink>, CaLinkError> {
if let Some(existing) = self.links.read().get(pv_name).cloned() {
return Ok(existing);
}
let channel = Arc::new(self.client().await?.create_channel(pv_name));
let conn_rx = channel.connection_events();
let monitor = channel
.subscribe_with_mask(0.0, CALINK_EVENT_MASK)
.await
.map_err(|e| CaLinkError::Subscribe {
pv: pv_name.to_string(),
reason: e.to_string(),
})?;
let cache: Arc<ArcSwap<Option<CachedSnapshot>>> = Arc::new(ArcSwap::from_pointee(None));
let connected = Arc::new(AtomicBool::new(false));
let meta: Arc<ArcSwap<Option<LinkMetadata>>> = Arc::new(ArcSwap::from_pointee(None));
let conn_task = self.handle.spawn(run_connection_watcher(
conn_rx,
connected.clone(),
channel.clone(),
self.handle.clone(),
meta.clone(),
pv_name.to_string(),
));
let task = self.handle.spawn(run_monitor(
monitor,
cache.clone(),
connected.clone(),
channel.clone(),
pv_name.to_string(),
self.db.clone(),
));
let link = Arc::new(CaLink {
cache,
connected,
meta,
channel,
_monitor_task: AbortOnDrop(task),
_conn_task: AbortOnDrop(conn_task),
});
let mut links = self.links.write();
if let Some(existing) = links.get(pv_name).cloned() {
return Ok(existing);
}
links.insert(pv_name.to_string(), link.clone());
Ok(link)
}
pub async fn wait_for_link_connected(
&self,
pv_name: &str,
timeout: std::time::Duration,
) -> bool {
let name = strip_ca_scheme(pv_name);
let link = match self.open(name).await {
Ok(l) => l,
Err(_) => return false,
};
let deadline = std::time::Instant::now() + timeout;
loop {
if link.value().is_some() {
return true;
}
if std::time::Instant::now() >= deadline {
return false;
}
tokio::time::sleep(std::time::Duration::from_millis(25)).await;
}
}
pub fn link_count(&self) -> usize {
self.links.read().len()
}
fn link_for(&self, name: &str) -> Option<Arc<CaLink>> {
if let Some(existing) = self.links.read().get(name).cloned() {
return Some(existing);
}
let resolver = self.clone();
let name = name.to_string();
block_in_place_or_warn(move || resolver.handle.block_on(resolver.open(&name)).ok())
}
}
/// Monitor task body: consumes monitor events until the stream ends.
///
/// For every successful event it sets the connected flag, stores the snapshot
/// in `cache` (tagged with the channel's element count, falling back to the
/// snapshot's own count), and, if a database is attached, asks it to dispatch
/// the external CP targets for `pv_name`. Error events are logged at debug
/// level and otherwise ignored. When the stream ends the connected flag is
/// cleared.
async fn run_monitor(
    mut monitor: crate::client::MonitorHandle,
    cache: Arc<ArcSwap<Option<CachedSnapshot>>>,
    connected: Arc<AtomicBool>,
    channel: Arc<CaChannel>,
    pv_name: String,
    db: Arc<RwLock<Option<PvDatabase>>>,
) {
    loop {
        let event = match monitor.recv().await {
            Some(event) => event,
            None => break,
        };
        let snapshot = match event {
            Ok(snapshot) => snapshot,
            Err(e) => {
                tracing::debug!(
                    pv = %pv_name,
                    error = %e,
                    "calink: monitor error event ignored, keeping last cached value"
                );
                continue;
            }
        };

        connected.store(true, Ordering::Release);
        let native_count = match channel.element_count() {
            Ok(count) => count,
            Err(_) => snapshot.value.count(),
        };
        let entry = CachedSnapshot {
            snapshot,
            native_count,
        };
        cache.store(Arc::new(Some(entry)));

        // Clone the handle out of the lock so the guard is released before
        // the await below.
        let attached = db.read().clone();
        if let Some(database) = attached {
            database.dispatch_external_cp_targets(&pv_name).await;
        }
    }
    connected.store(false, Ordering::Release);
}
async fn run_connection_watcher(
mut conn_rx: epics_base_rs::runtime::sync::broadcast::Receiver<crate::client::ConnectionEvent>,
connected: Arc<AtomicBool>,
channel: Arc<CaChannel>,
handle: tokio::runtime::Handle,
meta: Arc<ArcSwap<Option<LinkMetadata>>>,
pv_name: String,
) {
loop {
match conn_rx.recv().await {
Ok(evt) => {
if note_conn_event(&evt, &connected) {
handle.spawn(fetch_link_metadata(
channel.clone(),
meta.clone(),
pv_name.clone(),
));
}
}
Err(epics_base_rs::runtime::sync::broadcast::error::RecvError::Lagged(_)) => continue,
Err(epics_base_rs::runtime::sync::broadcast::error::RecvError::Closed) => return,
}
}
}
/// Applies a connection event to `flag` and reports whether it was a connect.
///
/// `Connected` sets the flag and returns `true`. `Disconnected` and
/// `Unresponsive` clear the flag and return `false`. Any other event leaves
/// the flag untouched and returns `false`.
fn note_conn_event(evt: &crate::client::ConnectionEvent, flag: &AtomicBool) -> bool {
    use crate::client::ConnectionEvent;
    let now_connected = match evt {
        ConnectionEvent::Connected => true,
        ConnectionEvent::Disconnected | ConnectionEvent::Unresponsive => false,
        _ => return false,
    };
    flag.store(now_connected, Ordering::Release);
    now_connected
}
/// Fetches channel info and a CTRL-class snapshot, then stores the metadata
/// built from them in `meta`.
///
/// Both fetches are best effort: a failed info request leaves the type and
/// element count unset, and a failed CTRL get is logged at debug level and
/// leaves the CTRL-derived fields unset. Metadata is stored in every case.
async fn fetch_link_metadata(
    channel: Arc<CaChannel>,
    meta: Arc<ArcSwap<Option<LinkMetadata>>>,
    pv_name: String,
) {
    let (dbf, element_count) = match channel.info().await.ok() {
        Some(info) => (Some(info.native_type), Some(info.element_count)),
        None => (None, None),
    };

    let ctrl_reply = channel.get_with_metadata_count(DbrClass::Ctrl, 1).await;
    if let Err(e) = &ctrl_reply {
        tracing::debug!(
            pv = %pv_name,
            error = %e,
            "calink: CTRL attribute get failed; serving DBF type / element count only"
        );
    }
    let ctrl = ctrl_reply.ok();

    let metadata = build_link_metadata(dbf, element_count, ctrl.as_ref());
    meta.store(Arc::new(Some(metadata)));
}
/// Assembles `LinkMetadata` from whatever pieces are available.
///
/// `dbf` and `element_count` populate the type and count fields. When `ctrl`
/// carries display info, the graphic limits, alarm limits, precision and
/// non-empty units / description are copied. When it carries control info,
/// the control limits are copied. Everything else keeps its default.
fn build_link_metadata(
    dbf: Option<DbFieldType>,
    element_count: Option<u32>,
    ctrl: Option<&Snapshot>,
) -> LinkMetadata {
    let mut md = LinkMetadata {
        dbf_type: dbf.map(map_dbf_type),
        element_count: element_count.map(i64::from),
        ..LinkMetadata::default()
    };

    let display = ctrl.and_then(|snap| snap.display.as_ref());
    let control = ctrl.and_then(|snap| snap.control.as_ref());

    if let Some(disp) = display {
        md.graphic_limits = Some((disp.lower_disp_limit, disp.upper_disp_limit));
        // Ordered lowest to highest: alarm, warning, warning, alarm.
        md.alarm_limits = Some((
            disp.lower_alarm_limit,
            disp.lower_warning_limit,
            disp.upper_warning_limit,
            disp.upper_alarm_limit,
        ));
        md.precision = Some(disp.precision);
        if !disp.units.is_empty() {
            md.units = Some(disp.units.as_str_lossy().into_owned());
        }
        if !disp.description.is_empty() {
            md.description = Some(disp.description.clone());
        }
    }

    if let Some(ctl) = control {
        md.control_limits = Some((ctl.lower_ctrl_limit, ctl.upper_ctrl_limit));
    }

    md
}
/// Converts a `DbFieldType` into the `LinkDbfType` variant of the same name.
fn map_dbf_type(t: DbFieldType) -> LinkDbfType {
    match t {
        // Text and enumerations.
        DbFieldType::String => LinkDbfType::String,
        DbFieldType::Enum => LinkDbfType::Enum,
        // Signed integers.
        DbFieldType::Char => LinkDbfType::Char,
        DbFieldType::Short => LinkDbfType::Short,
        DbFieldType::Long => LinkDbfType::Long,
        DbFieldType::Int64 => LinkDbfType::Int64,
        // Unsigned integers.
        DbFieldType::UShort => LinkDbfType::UShort,
        DbFieldType::ULong => LinkDbfType::ULong,
        DbFieldType::UInt64 => LinkDbfType::UInt64,
        // Floating point.
        DbFieldType::Float => LinkDbfType::Float,
        DbFieldType::Double => LinkDbfType::Double,
    }
}
/// `LinkSet` implementation backed by CA links. Every method strips an
/// optional `ca://` prefix from `name` first. All methods except
/// `is_connected` and `link_names` open the link on demand via `link_for`.
impl LinkSet for CaLinkResolver {
    /// `true` only for an already-registered link that is currently servable;
    /// never opens a link.
    fn is_connected(&self, name: &str) -> bool {
        let name = strip_ca_scheme(name);
        let links = self.links.read();
        links.get(name).map_or(false, |link| link.is_connected())
    }

    /// Cached value of the link, if it is servable.
    fn get_value(&self, name: &str) -> Option<EpicsValue> {
        let bare = strip_ca_scheme(name);
        let link = self.link_for(bare)?;
        link.value()
    }

    /// Writes `value` through the link's channel, blocking until the put call
    /// returns. `LinkPutOp::Plain` uses `put_nowait`; `LinkPutOp::Async` uses
    /// `put`.
    ///
    /// # Errors
    /// A message if the link cannot be opened, or the put error's text.
    fn put_value(&self, name: &str, value: EpicsValue, op: LinkPutOp) -> Result<(), String> {
        let name = strip_ca_scheme(name);
        let link = match self.link_for(name) {
            Some(link) => link,
            None => return Err(format!("CA link {name} not open")),
        };
        let channel = link.channel.clone();
        let outcome = block_in_place_or_warn(move || {
            self.handle.block_on(async move {
                match op {
                    LinkPutOp::Plain => channel.put_nowait(&value).await,
                    LinkPutOp::Async => channel.put(&value).await,
                }
            })
        });
        outcome.map_err(|e| e.to_string())
    }

    /// Alarm severity of the link, reported only when it is greater than zero.
    fn alarm_severity(&self, name: &str) -> Option<i32> {
        let bare = strip_ca_scheme(name);
        self.link_for(bare)?
            .alarm_severity()
            .filter(|severity| *severity > 0)
    }

    /// Alarm status of the link, if it is servable.
    fn alarm_status(&self, name: &str) -> Option<i32> {
        let bare = strip_ca_scheme(name);
        let link = self.link_for(bare)?;
        link.alarm_status()
    }

    /// Timestamp of the link's cached snapshot, if it is servable.
    fn time_stamp(&self, name: &str) -> Option<(i64, i32, u64)> {
        let bare = strip_ca_scheme(name);
        let link = self.link_for(bare)?;
        link.time_stamp()
    }

    /// Stored metadata of the link while it is flagged connected.
    fn link_metadata(&self, name: &str) -> Option<LinkMetadata> {
        let bare = strip_ca_scheme(name);
        let link = self.link_for(bare)?;
        link.link_metadata()
    }

    /// Names of all registered links, in the map's iteration order.
    fn link_names(&self) -> Vec<String> {
        let links = self.links.read();
        links.keys().cloned().collect()
    }
}
/// Removes one leading `ca://` from `name`; names without that exact prefix
/// are returned unchanged.
fn strip_ca_scheme(name: &str) -> &str {
    match name.strip_prefix("ca://") {
        Some(bare) => bare,
        None => name,
    }
}
/// Runs `f`, wrapping it in `tokio::task::block_in_place` when the caller is
/// on a multi-thread tokio runtime.
///
/// On any other runtime flavor, or outside a runtime, `f` is called directly.
/// Despite the name, no warning is emitted on that path.
// NOTE(review): callers pass closures that call `Handle::block_on`; confirm
// that running them inline on a non-multi-thread runtime is acceptable.
fn block_in_place_or_warn<F, R>(f: F) -> R
where
    F: FnOnce() -> R,
{
    use tokio::runtime::{Handle, RuntimeFlavor};
    let on_multi_thread = matches!(
        Handle::try_current().map(|handle| handle.runtime_flavor()),
        Ok(RuntimeFlavor::MultiThread)
    );
    if on_multi_thread {
        tokio::task::block_in_place(f)
    } else {
        f()
    }
}
/// Builds a `CaLinkResolver` on `handle`, attaches a clone of `db` to it,
/// registers a clone of the resolver with `db` as the `"ca"` link set, and
/// returns the resolver.
pub async fn install_calink_resolver(
    db: &PvDatabase,
    handle: tokio::runtime::Handle,
) -> CaLinkResolver {
    let installed = CaLinkResolver::new(handle);
    installed.attach_database(db.clone());
    db.register_link_set("ca", Arc::new(installed.clone())).await;
    installed
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ConnectionEvent;

    /// The `ca://` prefix is removed when present and bare names pass through.
    #[test]
    fn strip_ca_scheme_handles_both_forms() {
        assert_eq!(strip_ca_scheme("ca://OTHER:PV"), "OTHER:PV");
        assert_eq!(strip_ca_scheme("OTHER:PV"), "OTHER:PV");
    }

    /// The monitor mask is exactly `DBE_VALUE | DBE_ALARM`, with no `DBE_LOG`.
    #[test]
    fn calink_monitor_mask_is_dbca_value_alarm_without_log() {
        use crate::protocol::{DBE_ALARM, DBE_LOG, DBE_VALUE};
        assert_eq!(
            CALINK_EVENT_MASK,
            DBE_VALUE | DBE_ALARM,
            "calink monitor mask must equal dbCa's DBE_VALUE | DBE_ALARM"
        );
        assert_eq!(
            CALINK_EVENT_MASK & DBE_LOG,
            0,
            "calink must not request DBE_LOG — dbCa.c never does"
        );
    }

    /// Connect / disconnect / unresponsive events drive the flag; other
    /// events leave it alone.
    #[test]
    fn bug1_connection_event_tracks_disconnect() {
        let connected = AtomicBool::new(false);
        assert!(note_conn_event(&ConnectionEvent::Connected, &connected));
        assert!(connected.load(Ordering::Acquire));
        assert!(!note_conn_event(&ConnectionEvent::Disconnected, &connected));
        assert!(!connected.load(Ordering::Acquire));
        assert!(note_conn_event(&ConnectionEvent::Connected, &connected));
        assert!(connected.load(Ordering::Acquire));
        assert!(!note_conn_event(&ConnectionEvent::Unresponsive, &connected));
        assert!(!connected.load(Ordering::Acquire));
        connected.store(true, Ordering::Release);
        assert!(!note_conn_event(
            &ConnectionEvent::AccessRightsChanged {
                read: true,
                write: false,
            },
            &connected,
        ));
        assert!(connected.load(Ordering::Acquire));
    }

    /// Models the flag-plus-cache gating with a bare flag and cache: a cached
    /// snapshot must not be served while the flag is clear.
    #[test]
    fn bug1_disconnected_link_serves_no_stale_value() {
        let cache: Arc<ArcSwap<Option<Snapshot>>> = Arc::new(ArcSwap::from_pointee(None));
        let connected = Arc::new(AtomicBool::new(false));
        cache.store(Arc::new(Some(Snapshot::new(
            EpicsValue::Double(42.0),
            0,
            0,
            std::time::SystemTime::UNIX_EPOCH,
        ))));
        let is_connected = connected.load(Ordering::Acquire) && cache.load().as_ref().is_some();
        assert!(
            !is_connected,
            "a disconnected link must report not-connected even with a cached snapshot"
        );
        let value = if connected.load(Ordering::Acquire) {
            cache.load().as_ref().as_ref().map(|s| s.value.clone())
        } else {
            None
        };
        assert!(
            value.is_none(),
            "a disconnected link must serve no stale value"
        );
        connected.store(true, Ordering::Release);
        let is_connected = connected.load(Ordering::Acquire) && cache.load().as_ref().is_some();
        assert!(
            is_connected,
            "reconnected link with cache must be connected"
        );
    }

    /// A change in native type or element count invalidates the cache.
    #[test]
    fn calink_cache_invalidated_on_native_type_or_count_change() {
        assert!(
            cache_native_matches(DbFieldType::Double, 1, DbFieldType::Double, 1),
            "matching scalar type+count stays servable"
        );
        assert!(
            cache_native_matches(DbFieldType::Short, 10, DbFieldType::Short, 10),
            "matching waveform type+count stays servable"
        );
        assert!(
            !cache_native_matches(DbFieldType::Short, 1, DbFieldType::Double, 1),
            "a DBR-type change invalidates the old cache"
        );
        assert!(
            !cache_native_matches(DbFieldType::Short, 10, DbFieldType::Short, 5),
            "an element-count change invalidates the old cache"
        );
    }

    /// Snapshot fixture carrying both display and control info.
    fn ctrl_snapshot() -> Snapshot {
        use epics_base_rs::server::snapshot::{ControlInfo, DisplayInfo};
        let mut snap = Snapshot::new(
            EpicsValue::Double(0.0),
            0,
            0,
            std::time::SystemTime::UNIX_EPOCH,
        );
        snap.display = Some(DisplayInfo {
            units: "degC".into(),
            precision: 3,
            upper_disp_limit: 100.0,
            lower_disp_limit: -50.0,
            upper_alarm_limit: 90.0,
            upper_warning_limit: 80.0,
            lower_warning_limit: -20.0,
            lower_alarm_limit: -40.0,
            ..Default::default()
        });
        snap.control = Some(ControlInfo {
            upper_ctrl_limit: 95.0,
            lower_ctrl_limit: -45.0,
        });
        snap
    }

    #[test]
    fn build_link_metadata_numeric_maps_all_fields() {
        let snap = ctrl_snapshot();
        let md = build_link_metadata(Some(DbFieldType::Double), Some(1), Some(&snap));
        assert_eq!(md.dbf_type, Some(LinkDbfType::Double));
        assert_eq!(md.element_count, Some(1));
        assert_eq!(md.graphic_limits, Some((-50.0, 100.0)));
        assert_eq!(md.control_limits, Some((-45.0, 95.0)));
        assert_eq!(md.alarm_limits, Some((-40.0, -20.0, 80.0, 90.0)));
        assert_eq!(md.precision, Some(3));
        assert_eq!(md.units.as_deref(), Some("degC"));
    }

    #[test]
    fn build_link_metadata_string_pv_has_no_limits() {
        let snap = Snapshot::new(
            EpicsValue::String("x".into()),
            0,
            0,
            std::time::SystemTime::UNIX_EPOCH,
        );
        let md = build_link_metadata(Some(DbFieldType::String), Some(1), Some(&snap));
        assert_eq!(md.dbf_type, Some(LinkDbfType::String));
        assert_eq!(md.element_count, Some(1));
        assert_eq!(md.graphic_limits, None);
        assert_eq!(md.control_limits, None);
        assert_eq!(md.alarm_limits, None);
        assert_eq!(md.precision, None);
        assert_eq!(md.units, None);
    }

    #[test]
    fn build_link_metadata_no_info_no_ctrl_is_all_none() {
        let md = build_link_metadata(None, None, None);
        assert_eq!(md, LinkMetadata::default());
    }

    #[test]
    fn build_link_metadata_empty_units_omitted() {
        use epics_base_rs::server::snapshot::DisplayInfo;
        let mut snap = Snapshot::new(
            EpicsValue::Double(0.0),
            0,
            0,
            std::time::SystemTime::UNIX_EPOCH,
        );
        snap.display = Some(DisplayInfo {
            units: "".into(),
            precision: 0,
            ..Default::default()
        });
        let md = build_link_metadata(Some(DbFieldType::Double), Some(1), Some(&snap));
        assert_eq!(md.units, None);
        assert_eq!(md.precision, Some(0));
    }

    /// Every variant handled by `map_dbf_type` is asserted here, including
    /// `UShort` and `ULong`.
    #[test]
    fn map_dbf_type_covers_every_variant() {
        assert_eq!(map_dbf_type(DbFieldType::String), LinkDbfType::String);
        assert_eq!(map_dbf_type(DbFieldType::Short), LinkDbfType::Short);
        assert_eq!(map_dbf_type(DbFieldType::Float), LinkDbfType::Float);
        assert_eq!(map_dbf_type(DbFieldType::Enum), LinkDbfType::Enum);
        assert_eq!(map_dbf_type(DbFieldType::Char), LinkDbfType::Char);
        assert_eq!(map_dbf_type(DbFieldType::Long), LinkDbfType::Long);
        assert_eq!(map_dbf_type(DbFieldType::Double), LinkDbfType::Double);
        assert_eq!(map_dbf_type(DbFieldType::Int64), LinkDbfType::Int64);
        assert_eq!(map_dbf_type(DbFieldType::UInt64), LinkDbfType::UInt64);
        assert_eq!(map_dbf_type(DbFieldType::UShort), LinkDbfType::UShort);
        assert_eq!(map_dbf_type(DbFieldType::ULong), LinkDbfType::ULong);
    }
}