mod beacon_monitor;
mod circuit_breaker;
mod search;
mod state;
mod subscription;
mod sync_group;
mod transport;
mod types;
pub use sync_group::{SyncGroup, SyncGroupResults, SyncGroupStat, SyncGroupStatus};
pub use circuit_breaker::{BreakerConfig, BreakerState};
use std::collections::{HashMap, HashSet, VecDeque};
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::time::Duration;
use epics_base_rs::runtime::sync::{broadcast, mpsc, oneshot};
use parking_lot::Mutex;
use crate::channel::{AccessRights, ChannelInfo};
use crate::protocol::*;
use crate::repeater;
use epics_base_rs::error::{CaError, CaResult};
use epics_base_rs::server::snapshot::{DbrClass, Snapshot};
use epics_base_rs::types::{
DBR_STRING, DBR_TIME_INT, DBR_TIME_STRING, DbFieldType, EpicsValue, PvString, decode_dbr,
};
pub use state::{ChannelState, ConnectionEvent};
/// Picks the DBR type used to read an enum channel back as its string label.
///
/// Returns `Some(DBR_TIME_STRING)` or `Some(DBR_STRING)` (depending on
/// `want_time`) when `enum_as_string` is set and `native` is `Enum`. Otherwise
/// returns `None`, meaning no substitution applies.
pub fn enum_string_readback_dbr(
    native: DbFieldType,
    want_time: bool,
    enum_as_string: bool,
) -> Option<u16> {
    let is_enum = matches!(native, DbFieldType::Enum);
    if !(enum_as_string && is_enum) {
        return None;
    }
    let dbr = if want_time { DBR_TIME_STRING } else { DBR_STRING };
    Some(dbr)
}
/// Returns `Some(DBR_TIME_STRING)` for `Float`/`Double` channels so they can be
/// read back as strings; returns `None` for every other native type.
pub fn float_as_string_readback_dbr(native: DbFieldType) -> Option<u16> {
    match native {
        DbFieldType::Float | DbFieldType::Double => Some(DBR_TIME_STRING),
        _ => None,
    }
}
/// Chooses the DBR type a CLI tool uses to read back an enum channel.
///
/// For an `Enum` channel this is `DBR_TIME_INT` when `enum_as_number` is set,
/// otherwise `DBR_TIME_STRING`. Non-enum channels yield `None`.
pub fn enum_cli_readback_dbr(native: DbFieldType, enum_as_number: bool) -> Option<u16> {
    if !matches!(native, DbFieldType::Enum) {
        return None;
    }
    Some(match enum_as_number {
        true => DBR_TIME_INT,
        false => DBR_TIME_STRING,
    })
}
/// How an enum-typed channel's value is read back for subscriptions
/// (consumed by `subscription_readback_dbr`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnumReadback {
/// Keep the channel's native DBR type (no enum substitution).
Native,
/// Substitute the string-label DBR (`enum_cli_readback_dbr(native, false)`).
Label,
/// Substitute the numeric DBR (`enum_cli_readback_dbr(native, true)`).
Numeric,
}
impl EnumReadback {
    /// Maps this readback mode to a substitute DBR type for `native`.
    ///
    /// Returns `None` when the native type should be used unchanged: for
    /// `Native` mode, or when `native` is not an enum.
    fn enum_substitution(self, native: DbFieldType) -> Option<u16> {
        let enum_as_number = match self {
            EnumReadback::Native => return None,
            EnumReadback::Label => false,
            EnumReadback::Numeric => true,
        };
        enum_cli_readback_dbr(native, enum_as_number)
    }
}
/// Resolves the DBR type requested for a subscription.
///
/// Priority order:
/// 1. an enum substitution from `enum_readback`;
/// 2. if `float_as_string` is set, the string DBR for floating-point channels;
/// 3. otherwise the TIME variant of the native type.
pub fn subscription_readback_dbr(
    native: DbFieldType,
    enum_readback: EnumReadback,
    float_as_string: bool,
) -> u16 {
    if let Some(dbr) = enum_readback.enum_substitution(native) {
        return dbr;
    }
    let float_dbr = if float_as_string {
        float_as_string_readback_dbr(native)
    } else {
        None
    };
    float_dbr.unwrap_or_else(|| native.time_dbr_type())
}
use state::ChannelInner;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
/// A diagnostic event kept in `CaDiagnostics` history. Its `Display` impl
/// renders each variant as one human-readable line. The producers of these
/// events are not visible in this file.
#[derive(Debug, Clone)]
pub enum DiagEvent {
/// Displayed as `Connected <pv> @ <server>`; presumably `pv` connected via `server`.
Connected {
pv: String,
server: SocketAddr,
},
/// Displayed as `Disconnected <server> (<channels> channels)`; presumably
/// the number of channels affected by losing `server`.
Disconnected {
server: SocketAddr,
channels: usize,
},
/// Displayed as `Reconnected <pv> (restored=<n>, stale=<n>)`; presumably counts
/// of subscriptions restored / found stale on reconnect — confirm with producer.
Reconnected {
pv: String,
restored: u32,
stale: u32,
},
/// Displayed as `Unresponsive <server>`.
Unresponsive {
server: SocketAddr,
},
/// Displayed as `Responsive <server>`.
Responsive {
server: SocketAddr,
},
/// Displayed as `Beacon anomaly <server>`.
BeaconAnomaly {
server: SocketAddr,
},
}
impl std::fmt::Display for DiagEvent {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Connected { pv, server } => write!(f, "Connected {pv} @ {server}"),
Self::Disconnected { server, channels } => {
write!(f, "Disconnected {server} ({channels} channels)")
}
Self::Reconnected {
pv,
restored,
stale,
} => write!(f, "Reconnected {pv} (restored={restored}, stale={stale})"),
Self::Unresponsive { server } => write!(f, "Unresponsive {server}"),
Self::Responsive { server } => write!(f, "Responsive {server}"),
Self::BeaconAnomaly { server } => write!(f, "Beacon anomaly {server}"),
}
}
}
/// A `DiagEvent` stamped with the `Instant` at which `CaDiagnostics::record` stored it.
#[derive(Debug, Clone)]
pub struct DiagRecord {
/// Monotonic time the event was recorded.
pub time: std::time::Instant,
/// The recorded event.
pub event: DiagEvent,
}
/// Maximum number of `DiagRecord`s kept in `CaDiagnostics` history; once full,
/// the oldest record is evicted for each new one.
const EVENT_HISTORY_CAPACITY: usize = 256;
/// Maximum number of channels held by `OneShotChannelCache` before the
/// oldest-inserted entries are evicted.
const ONE_SHOT_CHANNEL_CACHE_CAPACITY: usize = 4096;
/// Bounded cache of channels reused by the batched `caget_many*` helpers, so
/// repeated reads of the same PV share one channel.
#[derive(Default)]
struct OneShotChannelCache {
/// Cached channels keyed by expanded PV name.
channels: HashMap<String, CaChannel>,
/// Keys of `channels` in insertion order; the front is evicted first.
order: VecDeque<String>,
}
impl OneShotChannelCache {
    /// Returns the cached channel for `pv_name`, creating and caching one via
    /// `client` (priority 0) on a miss.
    ///
    /// Eviction is FIFO by insertion, not LRU: a cache hit does not refresh an
    /// entry's position. While the cache holds more than
    /// `ONE_SHOT_CHANNEL_CACHE_CAPACITY` entries, the oldest names are dropped.
    fn get_or_create(&mut self, client: &CaClient, pv_name: String) -> CaChannel {
        if let Some(hit) = self.channels.get(&pv_name) {
            return hit.clone();
        }
        let created = client.create_channel_expanded(pv_name.clone(), 0);
        self.order.push_back(pv_name.clone());
        self.channels.insert(pv_name, created.clone());
        while self.channels.len() > ONE_SHOT_CHANNEL_CACHE_CAPACITY {
            match self.order.pop_front() {
                Some(evicted) => {
                    self.channels.remove(&evicted);
                }
                None => break,
            }
        }
        created
    }
}
/// Client health counters plus a bounded history of recent `DiagEvent`s.
///
/// The counters are public atomics. This file only reads them (in `snapshot`);
/// NOTE(review): presumably they are incremented by coordinator/search code
/// elsewhere — confirm.
pub struct CaDiagnostics {
pub connections: AtomicU64,
pub disconnections: AtomicU64,
pub reconnections: AtomicU64,
pub unresponsive_events: AtomicU64,
pub subscriptions_restored: AtomicU64,
pub subscriptions_stale: AtomicU64,
pub beacon_anomalies: AtomicU64,
pub search_requests: AtomicU64,
pub dropped_monitors: AtomicU64,
/// Most recent events, oldest first, capped at `EVENT_HISTORY_CAPACITY`.
history: std::sync::Mutex<Vec<DiagRecord>>,
}
impl Default for CaDiagnostics {
    /// Starts every counter at zero and pre-sizes the history for
    /// `EVENT_HISTORY_CAPACITY` records.
    fn default() -> Self {
        let zero = || AtomicU64::new(0);
        Self {
            connections: zero(),
            disconnections: zero(),
            reconnections: zero(),
            unresponsive_events: zero(),
            subscriptions_restored: zero(),
            subscriptions_stale: zero(),
            beacon_anomalies: zero(),
            search_requests: zero(),
            dropped_monitors: zero(),
            history: std::sync::Mutex::new(Vec::with_capacity(EVENT_HISTORY_CAPACITY)),
        }
    }
}
impl CaDiagnostics {
pub fn record(&self, event: DiagEvent) {
let record = DiagRecord {
time: std::time::Instant::now(),
event,
};
if let Ok(mut history) = self.history.lock() {
if history.len() >= EVENT_HISTORY_CAPACITY {
history.remove(0);
}
history.push(record);
}
}
pub fn snapshot(&self) -> DiagnosticsSnapshot {
let history = self.history.lock().map(|h| h.clone()).unwrap_or_default();
DiagnosticsSnapshot {
connections: self.connections.load(Ordering::Relaxed),
disconnections: self.disconnections.load(Ordering::Relaxed),
reconnections: self.reconnections.load(Ordering::Relaxed),
unresponsive_events: self.unresponsive_events.load(Ordering::Relaxed),
subscriptions_restored: self.subscriptions_restored.load(Ordering::Relaxed),
subscriptions_stale: self.subscriptions_stale.load(Ordering::Relaxed),
beacon_anomalies: self.beacon_anomalies.load(Ordering::Relaxed),
search_requests: self.search_requests.load(Ordering::Relaxed),
dropped_monitors: self.dropped_monitors.load(Ordering::Relaxed),
history,
}
}
}
/// Plain-value copy of `CaDiagnostics`, produced by `CaDiagnostics::snapshot`
/// and returned from `CaClient::diagnostics`.
#[derive(Debug, Clone)]
pub struct DiagnosticsSnapshot {
pub connections: u64,
pub disconnections: u64,
pub reconnections: u64,
pub unresponsive_events: u64,
pub subscriptions_restored: u64,
pub subscriptions_stale: u64,
pub beacon_anomalies: u64,
pub search_requests: u64,
pub dropped_monitors: u64,
/// Copy of the recent-event history, oldest first.
pub history: Vec<DiagRecord>,
}
impl std::fmt::Display for DiagnosticsSnapshot {
    /// Writes one `Label: value` line per counter.
    ///
    /// If any history exists, it is followed by the recent events, each shown
    /// with its offset in seconds from the first recorded event.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Connections: {}", self.connections)?;
        writeln!(f, "Disconnections: {}", self.disconnections)?;
        writeln!(f, "Reconnections: {}", self.reconnections)?;
        writeln!(f, "Unresponsive events: {}", self.unresponsive_events)?;
        writeln!(f, "Subscriptions restored: {}", self.subscriptions_restored)?;
        writeln!(f, "Subscriptions stale: {}", self.subscriptions_stale)?;
        writeln!(f, "Beacon anomalies: {}", self.beacon_anomalies)?;
        writeln!(f, "Search requests: {}", self.search_requests)?;
        writeln!(f, "Dropped monitors: {}", self.dropped_monitors)?;

        let Some(first) = self.history.first() else {
            return Ok(());
        };
        writeln!(f, "Recent events ({}):", self.history.len())?;
        for rec in &self.history {
            let offset = rec.time.duration_since(first.time).as_secs_f64();
            writeln!(f, " +{:.1}s {}", offset, rec.event)?;
        }
        Ok(())
    }
}
use subscription::SubscriptionRegistry;
use types::*;
pub use types::{CaException, CaExceptionHandler, CaExceptionKind};
/// Channel Access client.
///
/// Holds the senders used to talk to its background tasks (search engine,
/// transport manager, coordinator), state shared with the `CaChannel`s it
/// creates, and the join handles of the spawned tasks. Its `Drop` impl tears
/// those tasks down.
pub struct CaClient {
/// Requests to the search engine task.
search_tx: mpsc::UnboundedSender<SearchRequest>,
/// Commands to the transport manager task (fallback when no direct writer exists).
transport_tx: mpsc::UnboundedSender<TransportCommand>,
/// Requests to the coordinator task.
coord_tx: mpsc::UnboundedSender<CoordRequest>,
/// Channels reused by `caget_many*`.
one_shot_channels: Mutex<OneShotChannelCache>,
/// Outstanding operations, shared with the transport manager, coordinator and channels.
in_flight: InFlightOps,
/// Allocator for client-side channel IDs.
cid_alloc: CidAllocator,
/// Per-CID channel snapshots, shared with the coordinator and channels.
snapshots: ChannelSnapshots,
/// Direct per-circuit writers keyed by `(server_addr, priority)`, shared with the transport manager.
server_writers: DirectServerWriters,
/// Health counters and event history.
diagnostics: Arc<CaDiagnostics>,
/// Per-CID search attempt counters, shared with the search engine.
search_attempts: types::SearchAttempts,
/// Currently installed exception handler, if any.
exception_slot: types::CaExceptionSlot,
/// User/host identity, shared with the transport manager.
client_identity: types::ClientIdentitySlot,
// Task handles are retained so `Drop` can abort the tasks.
_coordinator: tokio::task::JoinHandle<()>,
_search_task: tokio::task::JoinHandle<()>,
_transport_task: tokio::task::JoinHandle<()>,
_beacon_task: tokio::task::JoinHandle<()>,
/// Discovery backends, kept alive for the client's lifetime.
_discovery_backends: Vec<Box<dyn crate::discovery::Backend>>,
/// Tasks forwarding discovery add/remove events to the search engine; aborted on drop.
_discovery_forwarders: Vec<tokio::task::JoinHandle<()>>,
}
/// Requests handled by the coordinator task (processed by `run_coordinator`,
/// which is not visible here).
// NOTE(review): `dead_code` is allowed presumably because some variants or
// fields are constructed or read only in other modules or feature builds —
// confirm and narrow this allow if possible.
#[allow(dead_code)]
enum CoordRequest {
/// Sent by `create_channel_expanded` for every new channel, valid name or not.
RegisterChannel {
cid: u32,
pv_name: String,
priority: u8,
conn_tx: broadcast::Sender<ConnectionEvent>,
},
/// `reply` is awaited by `CaChannel::wait_connected`; its timeout maps to `ChannelNotFound`.
WaitConnected {
cid: u32,
reply: oneshot::Sender<()>,
},
/// Subscription setup; `reply` presumably yields the allocated subscription id.
Subscribe {
cid: u32,
mask: u16,
deadband: f64,
callback_tx: mpsc::Sender<CaResult<Snapshot>>,
coalesce_slot: std::sync::Arc<subscription::CoalesceSlot>,
req_count: Option<u32>,
enum_readback: EnumReadback,
float_as_string: bool,
reply: oneshot::Sender<CaResult<u32>>,
},
Unsubscribe {
subid: u32,
},
MonitorConsumed {
subid: u32,
},
/// Sent by `ChannelLifecycle::drop` and by the one-shot `caget`/`caput`/`cainfo` helpers.
DropChannel {
cid: u32,
},
ForceRescanServer {
server_addr: SocketAddr,
kind: beacon_monitor::BeaconAnomalyKind,
},
BeaconArrival {
server_addr: SocketAddr,
anomaly: bool,
},
GetWatchdogDelay {
cid: u32,
reply: oneshot::Sender<Option<Duration>>,
},
/// Used by `CaClient::ioc_connection_count`.
GetIocConnectionCount {
reply: oneshot::Sender<usize>,
},
GetHostMinorProtocol {
cid: u32,
reply: oneshot::Sender<Option<u16>>,
},
/// Used by `CaClient::shutdown` and `CaClient::drop`; `reply` acknowledges the shutdown.
Shutdown {
reply: oneshot::Sender<()>,
},
}
/// Options for `CaClient::new_with_config`.
///
/// `Default` yields: no TLS, no explicit discovery config (the environment is
/// consulted instead), and no extra discovery backends.
#[derive(Default)]
pub struct CaClientConfig {
/// Client TLS configuration; a server-side `TlsConfig` is ignored with a warning.
#[cfg(feature = "experimental-rust-tls")]
pub tls: Option<crate::tls::TlsConfig>,
/// Default TLS server name passed to the transport manager (`CaClient::new`
/// fills it from `EPICS_CA_TLS_SERVER_NAME`).
#[cfg(feature = "experimental-rust-tls")]
pub tls_server_name: Option<String>,
/// Explicit discovery configuration; when `None`, `crate::discovery::from_env` is used.
pub discovery: Option<crate::discovery::DiscoveryConfig>,
/// Additional discovery backends appended to those built from `discovery`.
pub extra_backends: Vec<Box<dyn crate::discovery::Backend>>,
}
impl CaClient {
/// Creates a client configured from the environment.
///
/// With the `experimental-rust-tls` feature, client TLS settings come from
/// `crate::tls::client_from_env()`. An invalid configuration is logged and the
/// client falls back to plaintext. The SNI server name comes from
/// `EPICS_CA_TLS_SERVER_NAME`. Without the feature,
/// `CaClientConfig::default()` is used as-is.
pub async fn new() -> CaResult<Self> {
    #[cfg(not(feature = "experimental-rust-tls"))]
    let config = CaClientConfig::default();

    #[cfg(feature = "experimental-rust-tls")]
    let config = {
        let mut config = CaClientConfig::default();
        match crate::tls::client_from_env() {
            // `Ok(None)` leaves `tls` at its default of `None`.
            Ok(maybe_tls) => config.tls = maybe_tls,
            Err(e) => {
                tracing::error!(error = %e,
                    "EPICS_CA_TLS_* configuration is invalid; using plaintext");
            }
        }
        config.tls_server_name = epics_base_rs::runtime::env::get("EPICS_CA_TLS_SERVER_NAME");
        config
    };

    Self::new_with_config(config).await
}
/// Builds a client from an explicit `CaClientConfig` and spawns its background tasks.
///
/// Steps:
/// 1. Warn if the non-standard TLS extension is enabled.
/// 2. Spawn `repeater::ensure_repeater()` without awaiting it.
/// 3. Build the search address list from `parse_addr_list_with_hostnames()`,
///    plus a one-time `discover()` snapshot from every discovery backend. The
///    backends come from `config.discovery`, or the environment if that is
///    `None`, followed by `config.extra_backends`.
/// 4. Spawn the search engine, one forwarder per backend that supports
///    `subscribe()`, the transport manager, the coordinator and the beacon
///    monitor.
///
/// # Errors
/// Propagates the error from `parse_addr_list_with_hostnames()`, the only
/// fallible step in this body.
pub async fn new_with_config(config: CaClientConfig) -> CaResult<Self> {
#[cfg(feature = "experimental-rust-tls")]
if config.tls.is_some() {
tracing::warn!(
"═══════════════════════════════════════════════════════════════════════\n \
CA client TLS ENABLED — non-standard, Rust-only extension.\n \
Cannot connect to C softIoc, EDM, MEDM, CSS, or pyepics-based tools.\n \
See doc/11-tls-design.md for rationale.\n \
═══════════════════════════════════════════════════════════════════════"
);
}
// Fire-and-forget: the repeater check runs in its own task; its handle and result are discarded.
epics_base_rs::runtime::task::spawn(async { repeater::ensure_repeater().await });
let mut addr_list = parse_addr_list_with_hostnames()?;
// An explicit discovery config takes precedence over the environment.
let discovery_cfg = config.discovery.clone().or_else(crate::discovery::from_env);
let mut backends: Vec<Box<dyn crate::discovery::Backend>> = match discovery_cfg {
Some(cfg) => crate::discovery::build_backends(cfg),
None => Vec::new(),
};
backends.extend(config.extra_backends);
if !backends.is_empty() {
// Seed the static address list with a one-time, de-duplicated snapshot from every backend.
let mut discovered: Vec<SocketAddr> = Vec::new();
for b in &backends {
for addr in b.discover().await {
if !discovered.contains(&addr) {
discovered.push(addr);
}
}
}
if !discovered.is_empty() {
tracing::info!(
count = discovered.len(),
"discovered IOCs via service discovery: {:?}",
discovered
);
}
for addr in discovered {
// Only add addresses that are not already in the list; the port comes from the discovered address itself.
if !addr_list.iter().any(|e| e.sock == addr) {
let port = match addr {
SocketAddr::V4(a) => a.port(),
SocketAddr::V6(a) => a.port(),
};
addr_list.push(AddrEntry::new(addr, None, port));
}
}
}
// `(address, optional host name)` pairs for name-server (TCP search) targets.
let nameserver_entries = parse_nameserver_list();
// SNI overrides: name-server host names first, then `parse_tls_sni_map()` entries, which win on conflict.
#[cfg(feature = "experimental-rust-tls")]
let sni_overrides: std::collections::HashMap<SocketAddr, String> = {
let mut map: std::collections::HashMap<SocketAddr, String> = nameserver_entries
.iter()
.filter_map(|(addr, host)| host.clone().map(|h| (*addr, h)))
.collect();
for (addr, host) in parse_tls_sni_map() {
map.insert(addr, host);
}
map
};
let nameserver_addrs: Vec<SocketAddr> =
nameserver_entries.iter().map(|(a, _)| *a).collect();
// Channels wiring the background tasks together.
let (search_tx, search_rx) = mpsc::unbounded_channel();
let (search_resp_tx, search_resp_rx) = mpsc::unbounded_channel();
let (transport_tx, transport_rx) = mpsc::unbounded_channel();
let (transport_evt_tx, transport_evt_rx) = mpsc::unbounded_channel();
let (coord_tx, coord_rx) = mpsc::unbounded_channel();
let search_attempts: types::SearchAttempts = Arc::new(dashmap::DashMap::new());
let search_task = epics_base_rs::runtime::task::spawn(search::run_search_engine(
addr_list,
nameserver_addrs,
search_rx,
search_resp_tx,
search_attempts.clone(),
));
// One forwarder per subscribable backend. Addresses are reference-counted across
// Added/Removed events: AddAddress is sent on the first Added and RemoveAddress on
// the last Removed. Forwarding stops once the search engine's receiver is gone.
// NOTE(review): the counts start at 0 even for addresses already in the static
// list, so a Removed event could remove an address that was also statically
// configured — confirm how the search engine handles this.
let mut discovery_forwarders: Vec<tokio::task::JoinHandle<()>> = Vec::new();
for backend in &backends {
if let Some(mut rx) = backend.subscribe() {
let fwd_search_tx = search_tx.clone();
discovery_forwarders.push(epics_base_rs::runtime::task::spawn(async move {
let mut addr_refs: std::collections::HashMap<SocketAddr, usize> =
std::collections::HashMap::new();
while let Some(evt) = rx.recv().await {
match evt {
crate::discovery::DiscoveryEvent::Added { addr, .. } => {
let n = addr_refs.entry(addr).or_insert(0);
*n += 1;
if *n == 1
&& fwd_search_tx.send(SearchRequest::AddAddress(addr)).is_err()
{
break;
}
}
crate::discovery::DiscoveryEvent::Removed { addr, .. } => {
if let Some(n) = addr_refs.get_mut(&addr) {
*n -= 1;
if *n == 0 {
addr_refs.remove(&addr);
if fwd_search_tx
.send(SearchRequest::RemoveAddress(addr))
.is_err()
{
break;
}
}
} else {
tracing::debug!(
%addr,
"discovery Removed for untracked address — ignored"
);
}
}
}
}
}));
}
}
// Only client-side TLS configs are usable here; a server config is ignored with a warning.
#[cfg(feature = "experimental-rust-tls")]
let tls_arc = config.tls.as_ref().and_then(|t| match t {
crate::tls::TlsConfig::Client(arc) => Some(arc.clone()),
crate::tls::TlsConfig::Server(_) => {
tracing::warn!("server-side TlsConfig passed to CaClient; ignoring");
None
}
});
let in_flight = InFlightOps::new();
let cid_alloc = CidAllocator::new();
let snapshots: ChannelSnapshots = Arc::new(dashmap::DashMap::new());
let server_writers: DirectServerWriters = Arc::new(dashmap::DashMap::new());
let last_rx_at: ServerLastRxAt = Arc::new(dashmap::DashMap::new());
let client_identity: types::ClientIdentitySlot =
Arc::new(parking_lot::RwLock::new(types::ClientIdentity::from_env()));
// The TLS build additionally passes the client TLS config, default server name and SNI overrides.
let transport_task = {
#[cfg(feature = "experimental-rust-tls")]
{
epics_base_rs::runtime::task::spawn(transport::run_transport_manager(
transport_rx,
transport_evt_tx,
in_flight.clone(),
server_writers.clone(),
last_rx_at.clone(),
client_identity.clone(),
tls_arc,
config.tls_server_name.clone(),
sni_overrides,
))
}
#[cfg(not(feature = "experimental-rust-tls"))]
{
epics_base_rs::runtime::task::spawn(transport::run_transport_manager(
transport_rx,
transport_evt_tx,
in_flight.clone(),
server_writers.clone(),
last_rx_at.clone(),
client_identity.clone(),
))
}
};
let diagnostics = Arc::new(CaDiagnostics::default());
let exception_slot: types::CaExceptionSlot = Arc::new(parking_lot::RwLock::new(None));
// Control channel from the coordinator to the beacon monitor.
let (beacon_ctrl_tx, beacon_ctrl_rx) =
mpsc::unbounded_channel::<beacon_monitor::BeaconControl>();
let coordinator = epics_base_rs::runtime::task::spawn(run_coordinator(
coord_rx,
search_resp_rx,
transport_evt_rx,
search_tx.clone(),
transport_tx.clone(),
in_flight.clone(),
cid_alloc.clone(),
snapshots.clone(),
last_rx_at,
diagnostics.clone(),
exception_slot.clone(),
search_attempts.clone(),
beacon_ctrl_tx,
));
// The beacon monitor reports back to the coordinator through `coord_tx`.
let beacon_task = epics_base_rs::runtime::task::spawn(beacon_monitor::run_beacon_monitor(
coord_tx.clone(),
beacon_ctrl_rx,
));
Ok(Self {
search_tx,
transport_tx,
coord_tx,
one_shot_channels: Mutex::new(OneShotChannelCache::default()),
in_flight,
cid_alloc,
snapshots,
server_writers,
diagnostics,
search_attempts,
exception_slot,
client_identity,
_coordinator: coordinator,
_search_task: search_task,
_transport_task: transport_task,
_beacon_task: beacon_task,
_discovery_backends: backends,
_discovery_forwarders: discovery_forwarders,
})
}
pub fn diagnostics(&self) -> DiagnosticsSnapshot {
self.diagnostics.snapshot()
}
/// Installs `f` as the client's exception handler.
///
/// Returns the previously installed handler, if there was one.
pub fn set_exception_handler<F>(&self, f: F) -> Option<types::CaExceptionHandler>
where
    F: Fn(&types::CaException) + Send + Sync + 'static,
{
    let handler = Arc::new(f);
    self.exception_slot.write().replace(handler)
}
/// Removes the installed exception handler and returns it, if one was set.
pub fn clear_exception_handler(&self) -> Option<types::CaExceptionHandler> {
    let mut slot = self.exception_slot.write();
    slot.take()
}
/// Sets the `user` field of the client identity shared with the transport manager.
pub fn set_user_name(&self, user: impl Into<String>) {
    let user: String = user.into();
    let mut identity = self.client_identity.write();
    identity.user = user;
}
/// Sets the `host` field of the client identity shared with the transport manager.
pub fn set_host_name(&self, host: impl Into<String>) {
    let host: String = host.into();
    let mut identity = self.client_identity.write();
    identity.host = host;
}
/// Asks the coordinator how many IOC connections it currently holds.
///
/// Returns 0 if the coordinator is gone or drops the request without replying.
pub async fn ioc_connection_count(&self) -> usize {
    let (reply, response) = oneshot::channel();
    match self
        .coord_tx
        .send(CoordRequest::GetIocConnectionCount { reply })
    {
        Ok(()) => response.await.unwrap_or(0),
        Err(_) => 0,
    }
}
/// Requests an orderly coordinator shutdown and waits up to 2 s for the
/// acknowledgement.
///
/// Best-effort: send failures and timeouts are ignored.
pub async fn shutdown(&self) {
    let (reply, ack) = oneshot::channel();
    let request = CoordRequest::Shutdown { reply };
    let _ = self.coord_tx.send(request);
    let _ = tokio::time::timeout(Duration::from_secs(2), ack).await;
}
/// Creates a channel for `name` (after `expand_pv_name`) at priority 0.
pub fn create_channel(&self, name: &str) -> CaChannel {
    let expanded = expand_pv_name(name);
    self.create_channel_expanded(expanded, 0)
}
/// Creates a channel for `name` (after `expand_pv_name`) at `priority`,
/// clamped to at most 99.
pub fn create_channel_with_priority(&self, name: &str, priority: u8) -> CaChannel {
    const MAX_PRIORITY: u8 = 99;
    let expanded = expand_pv_name(name);
    self.create_channel_expanded(expanded, std::cmp::min(priority, MAX_PRIORITY))
}
/// Allocates a CID, registers the channel with the coordinator, and returns
/// its handle. For a valid name it also schedules the initial search.
///
/// A name that is empty, or whose padded form (`pad_string`) is `0xFFFF`
/// bytes or longer, is still registered but never searched for, so the
/// channel will not connect. Such names are logged at warn level and counted
/// in `ca_client_create_channel_rejects_total`.
fn create_channel_expanded(&self, pv_name: String, priority: u8) -> CaChannel {
    let cid = self.cid_alloc.allocate();
    let (conn_tx, _) = broadcast::channel(16);
    let padded_len = crate::protocol::pad_string(&pv_name).len();

    let rejection: Option<&'static str> = if pv_name.is_empty() {
        Some("empty")
    } else if padded_len >= 0xFFFF {
        Some("too long")
    } else {
        None
    };
    if let Some(reason) = rejection {
        tracing::warn!(
            cid,
            pv_name = %pv_name,
            len = pv_name.len(),
            padded_len,
            reason,
            "create_channel: invalid PV name rejected (libca ECA_BADSTR / ECA_UNAVAILINSERV parity); channel will not connect"
        );
        metrics::counter!("ca_client_create_channel_rejects_total").increment(1);
    }

    let _ = self.coord_tx.send(CoordRequest::RegisterChannel {
        cid,
        pv_name: pv_name.clone(),
        priority,
        conn_tx: conn_tx.clone(),
    });
    if rejection.is_none() {
        let _ = self.search_tx.send(SearchRequest::Schedule {
            cid,
            pv_name: pv_name.clone(),
            reason: SearchReason::Initial,
        });
    }

    CaChannel {
        cid,
        pv_name: Arc::from(pv_name.as_str()),
        priority,
        coord_tx: self.coord_tx.clone(),
        transport_tx: self.transport_tx.clone(),
        in_flight: self.in_flight.clone(),
        snapshots: self.snapshots.clone(),
        server_writers: self.server_writers.clone(),
        conn_tx,
        cached_read: Arc::new(Mutex::new(None)),
        search_attempts: self.search_attempts.clone(),
        user_data: Arc::new(Mutex::new(None)),
        // Shared by every clone; sends `DropChannel` when the last clone is dropped.
        _lifecycle: Arc::new(ChannelLifecycle {
            cid,
            coord_tx: self.coord_tx.clone(),
        }),
    }
}
/// Returns the cached channel for `name` (after `expand_pv_name`), creating it
/// on first use. Used by `caget_many_with_timeout`.
fn cached_one_shot_channel(&self, name: &str) -> CaChannel {
    let expanded = expand_pv_name(name);
    let mut cache = self.one_shot_channels.lock();
    cache.get_or_create(self, expanded)
}
/// Asks the search engine to add `addr` to its search destinations.
///
/// Ignored if the search task has already stopped.
pub fn add_address(&self, addr: SocketAddr) {
    let request = SearchRequest::AddAddress(addr);
    let _ = self.search_tx.send(request);
}
/// Sends `list` to the search engine as a `SetAddressList` request
/// (presumably replacing its search destinations).
///
/// Ignored if the search task has already stopped.
pub fn set_address_list(&self, list: Vec<SocketAddr>) {
    let request = SearchRequest::SetAddressList(list);
    let _ = self.search_tx.send(request);
}
/// One-shot read of `pv_name`.
///
/// Creates a fresh channel, waits up to 3 s for it to connect, performs a
/// `get`, then asks the coordinator to drop the channel.
///
/// # Errors
/// Connection timeout or any error from `get`.
pub async fn caget(&self, pv_name: &str) -> CaResult<(DbFieldType, EpicsValue)> {
    let channel = self.create_channel(pv_name);
    channel.wait_connected(Duration::from_secs(3)).await?;
    let outcome = channel.get().await;
    let drop_request = CoordRequest::DropChannel { cid: channel.cid };
    let _ = self.coord_tx.send(drop_request);
    outcome
}
/// Reads many PVs with the default 30 s timeout; see `caget_many_with_timeout`.
pub async fn caget_many<S>(&self, pv_names: &[S]) -> Vec<CaResult<(DbFieldType, EpicsValue)>>
where
    S: AsRef<str> + Sync,
{
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
    self.caget_many_with_timeout(pv_names, DEFAULT_TIMEOUT).await
}
/// Reads many PVs through cached one-shot channels; results keep the order of
/// `pv_names`.
///
/// First, a batched read is issued for every name. Entries that fail with
/// `Disconnected` or `ChannelNotFound` (e.g. channels not yet operational) are
/// then given up to `timeout` each to connect. A connect failure replaces
/// that entry's result with the connect error; channels that connect are
/// re-read in a second batch.
pub async fn caget_many_with_timeout<S>(
    &self,
    pv_names: &[S],
    timeout: Duration,
) -> Vec<CaResult<(DbFieldType, EpicsValue)>>
where
    S: AsRef<str> + Sync,
{
    let channels: Vec<CaChannel> = pv_names
        .iter()
        .map(|name| self.cached_one_shot_channel(name.as_ref()))
        .collect();
    let mut results = self.get_many_with_timeout(&channels, timeout).await;

    let mut retry_indices = Vec::new();
    for (idx, result) in results.iter().enumerate() {
        if let Err(CaError::Disconnected | CaError::ChannelNotFound(_)) = result {
            retry_indices.push(idx);
        }
    }
    if retry_indices.is_empty() {
        return results;
    }

    let waits = retry_indices.iter().map(|&idx| {
        let channel = channels[idx].clone();
        async move { (idx, channel.wait_connected(timeout).await) }
    });
    let connected = futures_util::future::join_all(waits).await;

    let mut ready_indices = Vec::new();
    let mut ready_channels = Vec::new();
    for (idx, outcome) in connected {
        if let Err(e) = outcome {
            results[idx] = Err(e);
            continue;
        }
        ready_indices.push(idx);
        ready_channels.push(channels[idx].clone());
    }

    let reread = self.get_many_with_timeout(&ready_channels, timeout).await;
    for (idx, result) in ready_indices.into_iter().zip(reread) {
        results[idx] = result;
    }
    results
}
/// Batched read of `channels` with the default 30 s timeout.
pub async fn get_many(
    &self,
    channels: &[CaChannel],
) -> Vec<CaResult<(DbFieldType, EpicsValue)>> {
    let default_timeout = Duration::from_secs(30);
    self.get_many_with_timeout(channels, default_timeout).await
}
/// Batched read of `channels`; delegates to `CaChannel::get_many_with_timeout`.
pub async fn get_many_with_timeout(
    &self,
    channels: &[CaChannel],
    timeout: Duration,
) -> Vec<CaResult<(DbFieldType, EpicsValue)>> {
    let reads = CaChannel::get_many_with_timeout(channels, timeout);
    reads.await
}
/// One-shot write of `value_str` to `pv_name`.
///
/// Creates a channel, waits up to 3 s for it to connect, and parses
/// `value_str` according to the channel's native type. The value is sent
/// with `put_nowait`, then the coordinator is asked to drop the channel.
///
/// # Errors
/// Connection timeout, disconnection, a parse failure, or any error from
/// `put_nowait`.
pub async fn caput(&self, pv_name: &str, value_str: &str) -> CaResult<()> {
    let channel = self.create_channel(pv_name);
    channel.wait_connected(Duration::from_secs(3)).await?;
    let native_type = channel.snapshot()?.native_type;
    let parsed = EpicsValue::parse(native_type, value_str)?;
    channel.put_nowait(&parsed).await?;
    let drop_request = CoordRequest::DropChannel { cid: channel.cid };
    let _ = self.coord_tx.send(drop_request);
    Ok(())
}
pub async fn caput_callback(
&self,
pv_name: &str,
value_str: &str,
timeout_secs: f64,
) -> CaResult<()> {
let ch = self.create_channel(pv_name);
let timeout = Duration::from_secs_f64(timeout_secs);
ch.wait_connected(timeout).await?;
let snap = ch.snapshot()?;
let value = EpicsValue::parse(snap.native_type, value_str)?;
ch.put_with_timeout(&value, timeout).await?;
let _ = self
.coord_tx
.send(CoordRequest::DropChannel { cid: ch.cid });
Ok(())
}
/// One-shot metadata query for `pv_name`.
///
/// Waits up to 3 s for the channel to connect, collects its `ChannelInfo`,
/// then asks the coordinator to drop the channel.
pub async fn cainfo(&self, pv_name: &str) -> CaResult<ChannelInfo> {
    let channel = self.create_channel(pv_name);
    channel.wait_connected(Duration::from_secs(3)).await?;
    let described = channel.info().await;
    let drop_request = CoordRequest::DropChannel { cid: channel.cid };
    let _ = self.coord_tx.send(drop_request);
    described
}
/// Subscribes to `pv_name` and calls `callback` with every delivered value.
///
/// Returns `Ok(())` when the subscription stream ends. Returns the first
/// error the stream delivers, or the error from `subscribe`.
pub async fn camonitor<F>(&self, pv_name: &str, mut callback: F) -> CaResult<()>
where
    F: FnMut(EpicsValue),
{
    let channel = self.create_channel(pv_name);
    let mut monitor = channel.subscribe().await?;
    loop {
        let Some(update) = monitor.recv().await else {
            return Ok(());
        };
        let snapshot = update?;
        callback(snapshot.value);
    }
}
}
impl Drop for CaClient {
    /// Tears down the client's background tasks.
    ///
    /// Discovery forwarders are aborted immediately. Inside a Tokio runtime, a
    /// detached task first asks the coordinator to shut down, waiting up to
    /// 2 s for the acknowledgement. It then aborts the coordinator, transport,
    /// search and beacon tasks. Outside a runtime, those four tasks are
    /// aborted directly.
    fn drop(&mut self) {
        for forwarder in &self._discovery_forwarders {
            forwarder.abort();
        }

        if tokio::runtime::Handle::try_current().is_err() {
            self._coordinator.abort();
            self._transport_task.abort();
            self._search_task.abort();
            self._beacon_task.abort();
            return;
        }

        let coord_tx = self.coord_tx.clone();
        // Abort order inside the task: coordinator, transport, search, beacon.
        let abort_handles = [
            self._coordinator.abort_handle(),
            self._transport_task.abort_handle(),
            self._search_task.abort_handle(),
            self._beacon_task.abort_handle(),
        ];
        tokio::spawn(async move {
            let (reply, ack) = oneshot::channel();
            if coord_tx.send(CoordRequest::Shutdown { reply }).is_ok() {
                let _ = tokio::time::timeout(Duration::from_secs(2), ack).await;
            }
            for handle in abort_handles {
                handle.abort();
            }
        });
    }
}
/// Shared (via `Arc`) by all clones of a `CaChannel`. Its `Drop` impl sends
/// `CoordRequest::DropChannel { cid }` once the last clone goes away.
struct ChannelLifecycle {
cid: u32,
coord_tx: mpsc::UnboundedSender<CoordRequest>,
}
impl Drop for ChannelLifecycle {
    /// Tells the coordinator this channel's last handle is gone.
    ///
    /// If the coordinator has already stopped, the send fails and the error is
    /// ignored.
    fn drop(&mut self) {
        let request = CoordRequest::DropChannel { cid: self.cid };
        let _ = self.coord_tx.send(request);
    }
}
/// A READ_NOTIFY request kept "warm" so a later batched read can reuse its
/// `ioid` and reply `slot`.
///
/// `get_many_with_timeout` reuses it only when `server_addr`, `sid`,
/// `data_type` and `element_count` still match the channel's current snapshot.
/// Otherwise the stale `ioid` is removed from `in_flight.reads`.
pub(crate) struct CachedRead {
pub(crate) ioid: u32,
pub(crate) sid: u32,
pub(crate) server_addr: SocketAddr,
pub(crate) data_type: u16,
pub(crate) element_count: u32,
/// Where the reply sender for the next reuse is installed.
pub(crate) slot: types::WarmReplySlot,
}
/// Rejects a put whose element `count` exceeds the channel's element count,
/// mirroring libca's `ECA_BADCOUNT` check in `nciu::write`.
fn validate_put_count(snap: &types::ChannelSnapshotPublic, count: u32) -> CaResult<()> {
    let capacity = snap.element_count;
    if count <= capacity {
        return Ok(());
    }
    Err(CaError::Protocol(format!(
        "put count {} exceeds channel element count {} \
         (matches libca nciu::write ECA_BADCOUNT)",
        count, capacity
    )))
}
/// CA fixed string slot size in bytes; strings of this length or longer are
/// rejected by `validate_string_length` (so at most 39 bytes are accepted).
const MAX_STRING_SIZE: usize = 40;
/// Accepts strings of at most `MAX_STRING_SIZE - 1` (39) bytes and rejects
/// longer ones, mirroring libca's `ECA_STRTOBIG` check in `nciu::stringVerify`.
fn validate_string_length(s: impl AsRef<[u8]>) -> CaResult<()> {
    let byte_len = s.as_ref().len();
    if byte_len < MAX_STRING_SIZE {
        Ok(())
    } else {
        let len = byte_len;
        Err(CaError::Protocol(format!(
            "string of {len} bytes exceeds MAX_STRING_SIZE - 1 = 39 \
             (matches libca nciu::stringVerify ECA_STRTOBIG)"
        )))
    }
}
fn validate_put_strings(value: &EpicsValue) -> CaResult<()> {
match value {
EpicsValue::String(s) => validate_string_length(s),
EpicsValue::StringArray(arr) => {
for s in arr {
validate_string_length(s)?;
}
Ok(())
}
_ => Ok(()),
}
}
/// Handle to one Channel Access channel.
///
/// Clones share the same `Arc`-wrapped state. The shared `_lifecycle` sends
/// `CoordRequest::DropChannel` when the last clone is dropped.
#[derive(Clone)]
pub struct CaChannel {
/// Client-side channel ID; key into `snapshots` and `search_attempts`.
cid: u32,
/// Expanded PV name; used in `ChannelNotFound` errors and `ChannelInfo`.
pv_name: Arc<str>,
/// CA priority; together with the server address it forms the circuit key for `server_writers`.
priority: u8,
/// Requests to the coordinator task.
coord_tx: mpsc::UnboundedSender<CoordRequest>,
/// Fallback path for requests when no direct writer exists for the circuit.
transport_tx: mpsc::UnboundedSender<TransportCommand>,
/// Outstanding operations (e.g. `reads` keyed by ioid).
in_flight: InFlightOps,
/// Per-CID connection snapshots shared with the client.
snapshots: ChannelSnapshots,
/// Direct per-circuit writers keyed by `(server_addr, priority)`.
server_writers: DirectServerWriters,
/// Connection-event broadcaster, also registered with the coordinator.
conn_tx: broadcast::Sender<ConnectionEvent>,
/// Warm READ_NOTIFY kept for reuse by batched reads.
cached_read: Arc<Mutex<Option<CachedRead>>>,
/// Per-CID search attempt counters.
search_attempts: types::SearchAttempts,
/// Optional user-attached value, shared by all clones.
user_data: Arc<Mutex<Option<Arc<dyn std::any::Any + Send + Sync>>>>,
/// Sends `DropChannel` when the last clone is dropped.
_lifecycle: Arc<ChannelLifecycle>,
}
#[cfg(test)]
impl CaChannel {
    /// Builds a detached test channel (cid 0, `test:pv`, priority 0) wired
    /// only to `coord_tx`.
    ///
    /// The transport and connection-event receivers are dropped immediately,
    /// so nothing listens on either sender.
    fn for_test(coord_tx: mpsc::UnboundedSender<CoordRequest>) -> Self {
        let cid = 0;
        let (transport_tx, _) = mpsc::unbounded_channel();
        let (conn_tx, _) = broadcast::channel(1);
        let lifecycle = ChannelLifecycle {
            cid,
            coord_tx: coord_tx.clone(),
        };
        Self {
            cid,
            pv_name: "test:pv".into(),
            priority: 0,
            coord_tx,
            transport_tx,
            in_flight: InFlightOps::new(),
            snapshots: ChannelSnapshots::default(),
            server_writers: DirectServerWriters::default(),
            conn_tx,
            cached_read: Arc::new(Mutex::new(None)),
            search_attempts: types::SearchAttempts::default(),
            user_data: Arc::new(Mutex::new(None)),
            _lifecycle: Arc::new(lifecycle),
        }
    }
}
/// Element count to request; turned into a concrete count by `resolve`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ReqCount {
/// Exactly `n` elements; `Fixed(0)` means "use the channel's native element count".
Fixed(u32),
/// `n` elements, passed through `resolve` unchanged (even when 0).
/// NOTE(review): any auto-sizing behaviour is applied outside this view — confirm.
Autosize(u32),
}
impl From<u32> for ReqCount {
fn from(count: u32) -> Self {
ReqCount::Fixed(count)
}
}
impl ReqCount {
    /// Returns the element count to request, given the channel's `native`
    /// element count.
    ///
    /// `Fixed(0)` defers to `native`. Every other `Fixed(n)` and every
    /// `Autosize(n)`, including `Autosize(0)`, yields `n` unchanged.
    pub fn resolve(self, native: u32) -> u32 {
        match self {
            ReqCount::Fixed(n) if n == 0 => native,
            ReqCount::Fixed(n) => n,
            ReqCount::Autosize(n) => n,
        }
    }
}
impl CaChannel {
/// Waits up to `timeout` for the coordinator to report this channel connected.
///
/// # Errors
/// - `CaError::ChannelNotFound(pv_name)` on timeout.
/// - `CaError::Shutdown` if the coordinator drops the reply or is not running.
pub async fn wait_connected(&self, timeout: Duration) -> CaResult<()> {
    let (reply, connected) = oneshot::channel();
    let _ = self.coord_tx.send(CoordRequest::WaitConnected {
        cid: self.cid,
        reply,
    });
    match tokio::time::timeout(timeout, connected).await {
        Err(_elapsed) => Err(CaError::ChannelNotFound(self.pv_name.to_string())),
        Ok(Err(_closed)) => Err(CaError::Shutdown),
        Ok(Ok(())) => Ok(()),
    }
}
/// Describes this channel from its current connection snapshot.
///
/// # Errors
/// `CaError::Disconnected` if the channel is not currently operational.
pub async fn info(&self) -> CaResult<ChannelInfo> {
    let snap = self.snapshot()?;
    let described = ChannelInfo {
        access_rights: snap.access_rights,
        element_count: snap.element_count,
        native_type: snap.native_type,
        server_addr: snap.server_addr,
        pv_name: self.pv_name.to_string(),
    };
    Ok(described)
}
pub fn search_attempts(&self) -> u32 {
self.search_attempts
.get(&self.cid)
.map(|e| e.load(std::sync::atomic::Ordering::Relaxed))
.unwrap_or(0)
}
/// Returns a copy of this channel's snapshot if its state is operational.
///
/// # Errors
/// `CaError::Disconnected` when there is no snapshot or it is not operational.
fn snapshot(&self) -> CaResult<ChannelSnapshotPublic> {
    let entry = self
        .snapshots
        .get(&self.cid)
        .ok_or(CaError::Disconnected)?;
    if entry.state.is_operational() {
        Ok(entry.clone())
    } else {
        Err(CaError::Disconnected)
    }
}
/// Looks up the direct writer registered for the circuit
/// `(server_addr, self.priority)`, if any.
///
/// The `*_fast` send paths use it in preference to `transport_tx`.
fn direct_writer(&self, server_addr: SocketAddr) -> Option<DirectServerWriter> {
    let circuit = (server_addr, self.priority);
    let writer = self.server_writers.get(&circuit)?;
    Some((*writer).clone())
}
/// Encodes a payload-less `CA_PROTO_READ_NOTIFY` header for server channel `sid`.
///
/// The `sid` goes in the header's `cid` field and `ioid` in `available`.
/// Counts below `0xFFFF` fit the 16-bit `count` field; larger counts are
/// carried via `set_payload_size(0, count)`.
fn build_read_notify_frame(sid: u32, data_type: u16, count: u32, ioid: u32) -> Vec<u8> {
    let mut header = CaHeader::new(CA_PROTO_READ_NOTIFY);
    header.data_type = data_type;
    header.cid = sid;
    header.available = ioid;
    match u16::try_from(count) {
        Ok(small) if small != u16::MAX => {
            header.count = small;
        }
        _ => {
            header.set_payload_size(0, count);
        }
    }
    header.to_bytes_extended()
}
/// Converts a read reply into `(type, value)`.
///
/// `Plain` replies are passed through; `Raw` replies are decoded from their
/// wire bytes.
///
/// # Errors
/// An unrecognised `data_type`, or a payload that fails to decode.
fn decode_plain_read_reply(reply: ReadReply) -> CaResult<(DbFieldType, EpicsValue)> {
    let (data_type, count, data) = match reply {
        ReadReply::Plain { dbr_type, value } => return Ok((dbr_type, value)),
        ReadReply::Raw {
            data_type,
            count,
            data,
        } => (data_type, count, data),
    };
    let dbr_type = DbFieldType::from_u16(data_type)?;
    let value = EpicsValue::from_bytes_array(dbr_type, &data, count as usize)?;
    Ok((dbr_type, value))
}
/// Encodes a write-family request (`cmd`) for server channel `sid`.
///
/// The frame is the header followed by `payload`, zero-padded to
/// `align8(len)` bytes. `ioid`, when present, is stored in the header's
/// `available` field.
fn build_write_frame(
    cmd: u16,
    sid: u32,
    data_type: u16,
    count: u32,
    ioid: Option<u32>,
    payload: Vec<u8>,
) -> Vec<u8> {
    let mut body = payload;
    let aligned_len = align8(body.len());
    body.resize(aligned_len, 0);

    let mut header = CaHeader::new(cmd);
    header.data_type = data_type;
    header.cid = sid;
    if let Some(id) = ioid {
        header.available = id;
    }
    header.set_payload_size(body.len(), count);

    let mut frame = header.to_bytes_extended();
    frame.extend(body);
    frame
}
/// Sends a READ_NOTIFY for `ioid`.
///
/// Uses the circuit's direct writer when one exists, otherwise routes the
/// request through `transport_tx`.
///
/// # Errors
/// - `CaError::Protocol` if the cached access rights deny reads.
/// - The direct writer's error.
/// - `CaError::Shutdown` if the transport task is gone.
fn send_read_notify_fast(
    &self,
    snap: &ChannelSnapshotPublic,
    data_type: u16,
    count: u32,
    ioid: u32,
) -> CaResult<()> {
    if !snap.access_rights.read {
        return Err(CaError::Protocol(format!(
            "read denied by cached access rights (matches libca \
             nciu::read ECA_NORDACCESS); ioid {ioid}"
        )));
    }
    match self.direct_writer(snap.server_addr) {
        Some(writer) => {
            let frame = Self::build_read_notify_frame(snap.sid, data_type, count, ioid);
            writer.send_frame(frame)
        }
        None => {
            let command = TransportCommand::ReadNotify {
                sid: snap.sid,
                data_type,
                count,
                ioid,
                server_addr: snap.server_addr,
                priority: self.priority,
            };
            self.transport_tx
                .send(command)
                .map_err(|_| CaError::Shutdown)
        }
    }
}
/// Sends a WRITE_NOTIFY carrying `payload` for `ioid`.
///
/// Uses the circuit's direct writer when one exists, otherwise routes the
/// request through `transport_tx`.
///
/// # Errors
/// - `CaError::Protocol` if the cached access rights deny writes.
/// - The direct writer's error.
/// - `CaError::Shutdown` if the transport task is gone.
fn send_write_notify_fast(
    &self,
    snap: &ChannelSnapshotPublic,
    data_type: u16,
    count: u32,
    ioid: u32,
    payload: Vec<u8>,
) -> CaResult<()> {
    if !snap.access_rights.write {
        return Err(CaError::Protocol(format!(
            "write denied by cached access rights (matches libca \
             nciu::write ECA_NOWTACCESS); ioid {ioid}"
        )));
    }
    let Some(writer) = self.direct_writer(snap.server_addr) else {
        let command = TransportCommand::WriteNotify {
            sid: snap.sid,
            data_type,
            count,
            ioid,
            payload,
            server_addr: snap.server_addr,
            priority: self.priority,
        };
        return self
            .transport_tx
            .send(command)
            .map_err(|_| CaError::Shutdown);
    };
    let frame = Self::build_write_frame(
        CA_PROTO_WRITE_NOTIFY,
        snap.sid,
        data_type,
        count,
        Some(ioid),
        payload,
    );
    writer.send_frame(frame)
}
/// Sends a plain WRITE (no completion notification requested) carrying `payload`.
///
/// Uses the circuit's direct writer when one exists, otherwise routes the
/// request through `transport_tx`.
///
/// # Errors
/// - `CaError::Protocol` if the cached access rights deny writes.
/// - The direct writer's error.
/// - `CaError::Shutdown` if the transport task is gone.
fn send_write_nowait_fast(
    &self,
    snap: &ChannelSnapshotPublic,
    data_type: u16,
    count: u32,
    payload: Vec<u8>,
) -> CaResult<()> {
    if !snap.access_rights.write {
        return Err(CaError::Protocol(
            "write denied by cached access rights (matches libca \
             nciu::write ECA_NOWTACCESS)"
                .to_owned(),
        ));
    }
    let Some(writer) = self.direct_writer(snap.server_addr) else {
        let command = TransportCommand::Write {
            sid: snap.sid,
            data_type,
            count,
            payload,
            server_addr: snap.server_addr,
            priority: self.priority,
        };
        return self
            .transport_tx
            .send(command)
            .map_err(|_| CaError::Shutdown);
    };
    let frame = Self::build_write_frame(CA_PROTO_WRITE, snap.sid, data_type, count, None, payload);
    writer.send_frame(frame)
}
/// Reads the channel's current value with the default 30 s timeout.
pub async fn get(&self) -> CaResult<(DbFieldType, EpicsValue)> {
    let default_timeout = Duration::from_secs(30);
    self.get_with_timeout(default_timeout).await
}
pub async fn get_many(channels: &[CaChannel]) -> Vec<CaResult<(DbFieldType, EpicsValue)>> {
Self::get_many_with_timeout(channels, Duration::from_secs(30)).await
}
/// Reads every channel in `channels` against one shared deadline of `timeout`.
///
/// All requests are issued before any reply is awaited. Channels with a direct
/// writer are batched per `(server_addr, priority)` circuit: their READ_NOTIFY
/// frames are concatenated and sent with one `send_frame` call per circuit.
/// Channels without a direct writer fall back to `send_read_notify_fast` (cold
/// reads only; see the NOTE in that branch).
///
/// A channel may hold a cached "warm" read in `cached_read`, which is a
/// pre-registered ioid plus reply slot. It is reused only when its `server_addr`,
/// `sid`, `data_type` and `element_count` still match the current snapshot;
/// otherwise it is discarded. After a successful cold read, a fresh warm entry is
/// registered for the next call.
///
/// Returns one result per input channel, in input order. Per-channel failures are
/// reported in place:
/// - the snapshot error;
/// - the `send_read_notify_fast` error;
/// - `CaError::Disconnected` when a batched frame could not be written, or a warm
///   read has no direct writer;
/// - `CaError::Timeout` when the shared deadline expires;
/// - `CaError::Shutdown` when the reply sender is dropped;
/// - the error carried in the reply itself.
pub async fn get_many_with_timeout(
channels: &[CaChannel],
timeout: Duration,
) -> Vec<CaResult<(DbFieldType, EpicsValue)>> {
// Bookkeeping needed to clean up, or re-arm the warm cache, once a reply (or the
// deadline) arrives.
enum PendingKind {
// Fresh ioid registered as a `ReadWaiter::OneShot`. The connection details are
// kept so that a warm entry can be built on success.
Cold {
ioid: u32,
in_flight: InFlightOps,
cid: u32,
cached_read_slot: Arc<Mutex<Option<CachedRead>>>,
sid: u32,
server_addr: SocketAddr,
data_type: u16,
element_count: u32,
},
// Reused the channel's cached warm ioid; `cached` is put back afterwards.
Warm {
ioid: u32,
in_flight: InFlightOps,
cached_read_slot: Arc<Mutex<Option<CachedRead>>>,
cached: CachedRead,
},
}
struct Pending {
// Position of the channel in the caller's slice (and in the result vector).
index: usize,
reply_rx: oneshot::Receiver<CaResult<ReadReply>>,
kind: PendingKind,
}
// All frames for one circuit, written with a single `send_frame` call.
struct BulkReadGroup {
writer: DirectServerWriter,
frame: Vec<u8>,
pending: Vec<Pending>,
}
let mut results: Vec<Option<CaResult<(DbFieldType, EpicsValue)>>> =
(0..channels.len()).map(|_| None).collect();
let mut groups: HashMap<types::CircuitKey, BulkReadGroup> = HashMap::new();
let mut pending: Vec<Pending> = Vec::new();
// Phase 1: build and register one read request per channel.
for (index, ch) in channels.iter().enumerate() {
let snap = match ch.snapshot() {
Ok(s) => s,
Err(e) => {
results[index] = Some(Err(e));
continue;
}
};
// Take the cached warm read only if it still describes the current
// connection. Otherwise drop it and unregister its stale ioid.
let warm_taken: Option<CachedRead> = {
let mut guard = ch.cached_read.lock();
let matches = matches!(guard.as_ref(), Some(c)
if c.server_addr == snap.server_addr
&& c.sid == snap.sid
&& c.data_type == snap.native_type as u16
&& c.element_count == snap.element_count);
if matches {
guard.take()
} else if guard.is_some() {
let stale = guard.take().unwrap();
ch.in_flight.reads.remove(&stale.ioid);
None
} else {
None
}
};
let (reply_tx, reply_rx) = oneshot::channel();
let (frame, kind) = if let Some(cached) = warm_taken {
// Arm the warm slot with this call's sender. Presumably the reply for
// `cached.ioid` is routed through this slot (`ReadWaiter::Warm`).
*cached.slot.lock() = Some(reply_tx);
let frame = Self::build_read_notify_frame(
cached.sid,
cached.data_type,
cached.element_count,
cached.ioid,
);
let kind = PendingKind::Warm {
ioid: cached.ioid,
in_flight: ch.in_flight.clone(),
cached_read_slot: ch.cached_read.clone(),
cached,
};
(frame, kind)
} else {
let ioid = ch.in_flight.alloc_ioid();
ch.in_flight.reads.insert(
ioid,
ReadWaiter::OneShot {
cid: ch.cid,
mode: ReadReplyMode::Plain,
reply_tx,
},
);
let frame = Self::build_read_notify_frame(
snap.sid,
snap.native_type as u16,
snap.element_count,
ioid,
);
let kind = PendingKind::Cold {
ioid,
in_flight: ch.in_flight.clone(),
cid: ch.cid,
cached_read_slot: ch.cached_read.clone(),
sid: snap.sid,
server_addr: snap.server_addr,
data_type: snap.native_type as u16,
element_count: snap.element_count,
};
(frame, kind)
};
let pending_read = Pending {
index,
reply_rx,
kind,
};
if let Some(writer) = ch.direct_writer(snap.server_addr) {
// Batch by circuit; the first channel seen for a circuit supplies its writer.
let group = groups
.entry((snap.server_addr, ch.priority))
.or_insert_with(|| BulkReadGroup {
writer,
frame: Vec::new(),
pending: Vec::new(),
});
group.frame.extend_from_slice(&frame);
group.pending.push(pending_read);
} else {
// No direct writer: the prebuilt `frame` is unused here. Cold reads are sent
// via `send_read_notify_fast`. Warm reads are failed as `Disconnected` and
// their cache entry is dropped.
// NOTE(review): confirm warm reads should not also fall back to
// `send_read_notify_fast` with the cached ioid.
match pending_read.kind {
PendingKind::Cold {
ioid, in_flight, ..
} => match ch.send_read_notify_fast(
&snap,
snap.native_type as u16,
snap.element_count,
ioid,
) {
Ok(()) => pending.push(Pending {
index,
reply_rx: pending_read.reply_rx,
kind: PendingKind::Cold {
ioid,
in_flight,
cid: ch.cid,
cached_read_slot: ch.cached_read.clone(),
sid: snap.sid,
server_addr: snap.server_addr,
data_type: snap.native_type as u16,
element_count: snap.element_count,
},
}),
Err(e) => {
in_flight.reads.remove(&ioid);
results[index] = Some(Err(e));
}
},
PendingKind::Warm {
ioid,
in_flight,
cached_read_slot,
..
} => {
in_flight.reads.remove(&ioid);
*cached_read_slot.lock() = None;
results[index] = Some(Err(CaError::Disconnected));
}
}
}
}
// Phase 2: flush one frame per circuit. If a write fails, unregister every
// request in that frame and report it as `Disconnected`.
for (_, group) in groups {
match group.writer.send_frame(group.frame) {
Ok(()) => pending.extend(group.pending),
Err(_) => {
for p in group.pending {
match p.kind {
PendingKind::Cold {
ioid, in_flight, ..
} => {
in_flight.reads.remove(&ioid);
}
PendingKind::Warm {
ioid,
in_flight,
cached_read_slot,
..
} => {
in_flight.reads.remove(&ioid);
*cached_read_slot.lock() = None;
}
}
results[p.index] = Some(Err(CaError::Disconnected));
}
}
}
}
// Phase 3: await replies. Every reply shares one absolute deadline, so the
// total wait is bounded by `timeout`, not `timeout * n`.
let deadline = tokio::time::Instant::now() + timeout;
for p in pending {
let Pending {
index,
reply_rx,
kind,
} = p;
let result = tokio::time::timeout_at(deadline, reply_rx).await;
let decoded: CaResult<(DbFieldType, EpicsValue)> = match result {
Ok(Ok(Ok(reply))) => Self::decode_plain_read_reply(reply),
Ok(Ok(Err(e))) => Err(e),
Ok(Err(_)) => Err(CaError::Shutdown),
Err(_) => Err(CaError::Timeout),
};
// Timeout/Shutdown mean no reply was delivered, so the waiter may still be
// registered and is removed below.
let is_local_error = matches!(decoded, Err(CaError::Timeout) | Err(CaError::Shutdown));
match kind {
PendingKind::Cold {
ioid,
in_flight,
cid,
cached_read_slot,
sid,
server_addr,
data_type,
element_count,
} => {
if is_local_error {
in_flight.reads.remove(&ioid);
}
if decoded.is_ok() {
// Register a warm waiter for the next bulk read. If another call
// already filled the cache, discard this one and unregister its ioid.
let warm_ioid = in_flight.alloc_ioid();
let slot: types::WarmReplySlot = Arc::new(parking_lot::Mutex::new(None));
in_flight.reads.insert(
warm_ioid,
ReadWaiter::Warm {
cid,
mode: ReadReplyMode::Plain,
slot: slot.clone(),
},
);
let cached = CachedRead {
ioid: warm_ioid,
sid,
server_addr,
data_type,
element_count,
slot,
};
let mut guard = cached_read_slot.lock();
if guard.is_none() {
*guard = Some(cached);
} else {
drop(guard);
in_flight.reads.remove(&warm_ioid);
}
}
}
PendingKind::Warm {
ioid,
in_flight,
cached_read_slot,
cached,
} => {
if is_local_error {
// No reply arrived: retire the warm entry entirely.
in_flight.reads.remove(&ioid);
drop(cached);
*cached_read_slot.lock() = None;
} else {
// A reply (success or server error) arrived: put the warm entry back
// unless a concurrent call already re-populated the cache.
let mut guard = cached_read_slot.lock();
if guard.is_none() {
*guard = Some(cached);
} else {
drop(guard);
in_flight.reads.remove(&ioid);
}
}
}
}
results[index] = Some(decoded);
}
// Every index is filled above; the `Shutdown` fallback is defensive.
results
.into_iter()
.map(|r| r.unwrap_or(Err(CaError::Shutdown)))
.collect()
}
/// Returns the channel's native field type from the current connection snapshot.
///
/// # Errors
/// Propagates the error returned by `snapshot()`.
pub fn native_field_type(&self) -> CaResult<DbFieldType> {
    self.snapshot().map(|snap| snap.native_type)
}
/// Returns the channel's element count from the current connection snapshot.
///
/// # Errors
/// Propagates the error returned by `snapshot()`.
pub fn element_count(&self) -> CaResult<u32> {
    self.snapshot().map(|snap| snap.element_count)
}
pub async fn get_with_timeout(&self, timeout: Duration) -> CaResult<(DbFieldType, EpicsValue)> {
self.get_with_timeout_count(timeout, 0).await
}
/// Reads `count` elements of the channel in its native type, waiting at most `timeout`.
///
/// The request is registered as a one-shot plain read. The waiter is always
/// unregistered before returning.
///
/// # Errors
/// - `CaError::Protocol` when the resolved count exceeds the channel's element count.
/// - The send error if the READ_NOTIFY could not be queued.
/// - `CaError::Timeout` when no reply arrives in time.
/// - `CaError::Shutdown` when the reply sender is dropped.
/// - The error carried in the reply, or any decode error.
pub async fn get_with_timeout_count(
    &self,
    timeout: Duration,
    count: impl Into<ReqCount>,
) -> CaResult<(DbFieldType, EpicsValue)> {
    let snap = self.snapshot()?;
    let wanted = count.into().resolve(snap.element_count);
    if wanted > snap.element_count {
        return Err(CaError::Protocol(format!(
            "get count {} exceeds channel element count {} \
             (matches libca nciu::read ECA_BADCOUNT)",
            wanted, snap.element_count
        )));
    }

    let ioid = self.in_flight.alloc_ioid();
    let (reply_tx, reply_rx) = oneshot::channel();
    let waiter = ReadWaiter::OneShot {
        cid: self.cid,
        mode: ReadReplyMode::Plain,
        reply_tx,
    };
    self.in_flight.reads.insert(ioid, waiter);

    let sent = self.send_read_notify_fast(&snap, snap.native_type as u16, wanted, ioid);
    if let Err(err) = sent {
        self.in_flight.reads.remove(&ioid);
        return Err(err);
    }

    let outcome = tokio::time::timeout(timeout, reply_rx).await;
    self.in_flight.reads.remove(&ioid);
    let reply = match outcome {
        Err(_) => return Err(CaError::Timeout),
        Ok(Err(_)) => return Err(CaError::Shutdown),
        Ok(Ok(server_reply)) => server_reply?,
    };
    Self::decode_plain_read_reply(reply)
}
pub async fn get_with_metadata(&self, class: DbrClass) -> CaResult<Snapshot> {
self.get_with_metadata_count(class, 0).await
}
/// Reads `count` elements of the channel using the DBR type matching `class`
/// (TIME, CTRL, STS, GR, or the plain native type).
///
/// # Errors
/// Propagates snapshot errors and any error from the underlying DBR request.
pub async fn get_with_metadata_count(
    &self,
    class: DbrClass,
    count: impl Into<ReqCount>,
) -> CaResult<Snapshot> {
    let snap = self.snapshot()?;
    // The snapshot already carries the native type as a `DbFieldType`, so use it
    // directly. Converting to u16 and back through `from_u16` adds nothing but an
    // unreachable error path.
    let native = snap.native_type;
    let request_type = match class {
        DbrClass::Time => native.time_dbr_type(),
        DbrClass::Ctrl => native.ctrl_dbr_type(),
        DbrClass::Sts => native.sts_dbr_type(),
        DbrClass::Gr => native.gr_dbr_type(),
        DbrClass::Plain => native.to_dbr_type() as u16,
    };
    self.get_dbr_request(snap, request_type, count.into()).await
}
/// Reads `count` elements of the channel, requesting the explicit DBR type `dbr_type`.
pub async fn get_with_dbr_type(
    &self,
    dbr_type: u16,
    count: impl Into<ReqCount>,
) -> CaResult<Snapshot> {
    let snap = self.snapshot()?;
    let requested: ReqCount = count.into();
    self.get_dbr_request(snap, dbr_type, requested).await
}
/// Sends a raw-mode READ_NOTIFY for `request_type` and decodes the reply into a
/// `Snapshot`. Waits at most 30 seconds.
///
/// # Errors
/// - `CaError::Protocol` when the resolved count exceeds the element count.
/// - `CaError::Protocol` when the reply comes back in plain rather than raw form.
/// - The send error.
/// - `CaError::Timeout` or `CaError::Shutdown`.
/// - The reply's own error, or the error from `decode_dbr`.
async fn get_dbr_request(
    &self,
    snap: ChannelSnapshotPublic,
    request_type: u16,
    count: ReqCount,
) -> CaResult<Snapshot> {
    let wanted = count.resolve(snap.element_count);
    if wanted > snap.element_count {
        return Err(CaError::Protocol(format!(
            "get count {} exceeds channel element count {} \
             (matches libca nciu::read ECA_BADCOUNT)",
            wanted, snap.element_count
        )));
    }

    let ioid = self.in_flight.alloc_ioid();
    let (reply_tx, reply_rx) = oneshot::channel();
    let waiter = ReadWaiter::OneShot {
        cid: self.cid,
        mode: ReadReplyMode::Raw,
        reply_tx,
    };
    self.in_flight.reads.insert(ioid, waiter);

    if let Err(err) = self.send_read_notify_fast(&snap, request_type, wanted, ioid) {
        self.in_flight.reads.remove(&ioid);
        return Err(err);
    }

    let outcome = tokio::time::timeout(Duration::from_secs(30), reply_rx).await;
    self.in_flight.reads.remove(&ioid);
    let reply = match outcome {
        Err(_) => return Err(CaError::Timeout),
        Ok(Err(_)) => return Err(CaError::Shutdown),
        Ok(Ok(server_reply)) => server_reply?,
    };

    if let ReadReply::Raw {
        data_type,
        count,
        data,
    } = reply
    {
        decode_dbr(data_type, &data, count as usize)
    } else {
        Err(CaError::Protocol(
            "metadata read returned a plain scalar reply".into(),
        ))
    }
}
/// Writes `value` in the channel's native type and waits for the server's
/// write-completion reply.
///
/// The timeout comes from `EPICS_CA_PUT_TIMEOUT` (seconds, as an `f64`). It defaults
/// to 30 seconds when the variable is unset, unparsable, or not a valid `Duration`
/// (negative, NaN, or too large). Those values used to panic inside
/// `Duration::from_secs_f64`.
///
/// # Errors
/// Same as `put_with_timeout`.
pub async fn put(&self, value: &EpicsValue) -> CaResult<()> {
    let timeout = epics_base_rs::runtime::env::get("EPICS_CA_PUT_TIMEOUT")
        .and_then(|s| s.parse::<f64>().ok())
        .and_then(|secs| Duration::try_from_secs_f64(secs).ok())
        .unwrap_or(Duration::from_secs(30));
    self.put_with_timeout(value, timeout).await
}
/// Writes `value` in the channel's native type and waits up to `timeout` for the
/// server's write-completion reply. The pending-write entry is always removed
/// before returning.
///
/// # Errors
/// - Snapshot or validation errors.
/// - The send error.
/// - `CaError::Timeout` when `timeout` elapses.
/// - `CaError::Shutdown` when the reply sender is dropped.
/// - The result reported by the server.
pub async fn put_with_timeout(&self, value: &EpicsValue, timeout: Duration) -> CaResult<()> {
    let snap = self.snapshot()?;
    validate_put_count(&snap, value.count())?;
    validate_put_strings(value)?;

    let ioid = self.in_flight.alloc_ioid();
    let (done_tx, done_rx) = oneshot::channel();
    self.in_flight.writes.insert(ioid, (self.cid, done_tx));

    let bytes = value.to_bytes();
    let elements = value.count() as u32;
    let native = snap.native_type as u16;
    if let Err(err) = self.send_write_notify_fast(&snap, native, elements, ioid, bytes) {
        self.in_flight.writes.remove(&ioid);
        return Err(err);
    }

    let waited = tokio::time::timeout(timeout, done_rx).await;
    self.in_flight.writes.remove(&ioid);
    match waited {
        Err(_) => Err(CaError::Timeout),
        Ok(Err(_)) => Err(CaError::Shutdown),
        Ok(Ok(server_result)) => server_result,
    }
}
/// Writes `value` in the channel's native type without waiting for a
/// completion reply.
pub async fn put_nowait(&self, value: &EpicsValue) -> CaResult<()> {
    let snap = self.snapshot()?;
    validate_put_count(&snap, value.count())?;
    validate_put_strings(value)?;
    let bytes = value.to_bytes();
    let elements = value.count() as u32;
    let native = snap.native_type as u16;
    self.send_write_nowait_fast(&snap, native, elements, bytes)
}
/// Writes `value` encoded as the explicit DBR type `dbr_type` and waits up to
/// `timeout` for the server's write-completion reply.
///
/// # Errors
/// Same as `put_with_timeout`.
pub async fn put_as_dbr_with_timeout(
    &self,
    dbr_type: u16,
    value: &EpicsValue,
    timeout: Duration,
) -> CaResult<()> {
    let snap = self.snapshot()?;
    let elements = value.count() as u32;
    validate_put_count(&snap, elements)?;
    validate_put_strings(value)?;

    let ioid = self.in_flight.alloc_ioid();
    let (done_tx, done_rx) = oneshot::channel();
    self.in_flight.writes.insert(ioid, (self.cid, done_tx));

    let bytes = value.to_bytes();
    if let Err(err) = self.send_write_notify_fast(&snap, dbr_type, elements, ioid, bytes) {
        self.in_flight.writes.remove(&ioid);
        return Err(err);
    }

    let waited = tokio::time::timeout(timeout, done_rx).await;
    self.in_flight.writes.remove(&ioid);
    match waited {
        Err(_) => Err(CaError::Timeout),
        Ok(Err(_)) => Err(CaError::Shutdown),
        Ok(Ok(server_result)) => server_result,
    }
}
/// Writes `value` encoded as the explicit DBR type `dbr_type` without waiting
/// for a completion reply.
pub async fn put_as_dbr_nowait(&self, dbr_type: u16, value: &EpicsValue) -> CaResult<()> {
    let snap = self.snapshot()?;
    let elements = value.count() as u32;
    validate_put_count(&snap, elements)?;
    validate_put_strings(value)?;
    let bytes = value.to_bytes();
    self.send_write_nowait_fast(&snap, dbr_type, elements, bytes)
}
/// Writes a single string to the channel (DBR type 0, i.e. DBR_STRING) and waits
/// for the server's write-completion reply.
///
/// The timeout comes from `EPICS_CA_PUT_TIMEOUT` (seconds, as an `f64`). It defaults
/// to 30 seconds when the variable is unset, unparsable, or not a valid `Duration`;
/// negative, NaN or huge values no longer panic.
///
/// # Errors
/// - Snapshot, count or string-length validation errors.
/// - The send error.
/// - `CaError::Timeout` or `CaError::Shutdown`.
/// - The result reported by the server.
pub async fn put_string(&self, value: &str) -> CaResult<()> {
    let snap = self.snapshot()?;
    validate_put_count(&snap, 1)?;
    validate_string_length(value)?;
    let ioid = self.in_flight.alloc_ioid();
    let (reply_tx, reply_rx) = oneshot::channel();
    self.in_flight.writes.insert(ioid, (self.cid, reply_tx));
    let payload = EpicsValue::String(value.into()).to_bytes();
    // DBR type 0 (DBR_STRING), one element.
    if let Err(e) = self.send_write_notify_fast(&snap, 0, 1, ioid, payload) {
        self.in_flight.writes.remove(&ioid);
        return Err(e);
    }
    // `Duration::try_from_secs_f64` rejects negative/NaN/overflowing input instead of
    // panicking like `from_secs_f64` does on a bad env value.
    let timeout = epics_base_rs::runtime::env::get("EPICS_CA_PUT_TIMEOUT")
        .and_then(|s| s.parse::<f64>().ok())
        .and_then(|secs| Duration::try_from_secs_f64(secs).ok())
        .unwrap_or(Duration::from_secs(30));
    let result = tokio::time::timeout(timeout, reply_rx).await;
    self.in_flight.writes.remove(&ioid);
    result
        .map_err(|_| CaError::Timeout)?
        .map_err(|_| CaError::Shutdown)?
}
/// Writes a single string (DBR type 0) without waiting for a completion reply.
pub async fn put_string_nowait(&self, value: &str) -> CaResult<()> {
    let snap = self.snapshot()?;
    validate_put_count(&snap, 1)?;
    validate_string_length(value)?;
    let encoded = EpicsValue::String(value.into()).to_bytes();
    // DBR type 0 (DBR_STRING), exactly one element.
    let (string_dbr, one_element) = (0, 1);
    self.send_write_nowait_fast(&snap, string_dbr, one_element, encoded)
}
/// Writes an array of strings (DBR type 0, i.e. DBR_STRING) and waits for the
/// server's write-completion reply.
///
/// The timeout comes from `EPICS_CA_PUT_TIMEOUT` (seconds, as an `f64`). It defaults
/// to 30 seconds when the variable is unset, unparsable, or not a valid `Duration`;
/// negative, NaN or huge values no longer panic.
///
/// # Errors
/// - Snapshot, count or string-length validation errors.
/// - The send error.
/// - `CaError::Timeout` or `CaError::Shutdown`.
/// - The result reported by the server.
pub async fn put_string_array(&self, values: &[String]) -> CaResult<()> {
    let snap = self.snapshot()?;
    let count = values.len() as u32;
    validate_put_count(&snap, count)?;
    for s in values {
        validate_string_length(s)?;
    }
    let ioid = self.in_flight.alloc_ioid();
    let (reply_tx, reply_rx) = oneshot::channel();
    self.in_flight.writes.insert(ioid, (self.cid, reply_tx));
    let payload =
        EpicsValue::StringArray(values.iter().map(PvString::from).collect()).to_bytes();
    // DBR type 0 (DBR_STRING).
    if let Err(e) = self.send_write_notify_fast(&snap, 0, count, ioid, payload) {
        self.in_flight.writes.remove(&ioid);
        return Err(e);
    }
    // `Duration::try_from_secs_f64` rejects negative/NaN/overflowing input instead of
    // panicking like `from_secs_f64` does on a bad env value.
    let timeout = epics_base_rs::runtime::env::get("EPICS_CA_PUT_TIMEOUT")
        .and_then(|s| s.parse::<f64>().ok())
        .and_then(|secs| Duration::try_from_secs_f64(secs).ok())
        .unwrap_or(Duration::from_secs(30));
    let result = tokio::time::timeout(timeout, reply_rx).await;
    self.in_flight.writes.remove(&ioid);
    result
        .map_err(|_| CaError::Timeout)?
        .map_err(|_| CaError::Shutdown)?
}
/// Writes an array of strings (DBR type 0) without waiting for a completion reply.
pub async fn put_string_array_nowait(&self, values: &[String]) -> CaResult<()> {
    let snap = self.snapshot()?;
    let elements = values.len() as u32;
    validate_put_count(&snap, elements)?;
    for text in values {
        validate_string_length(text)?;
    }
    let encoded =
        EpicsValue::StringArray(values.iter().map(PvString::from).collect()).to_bytes();
    self.send_write_nowait_fast(&snap, 0, elements, encoded)
}
/// Subscribes with a zero deadband (see `subscribe_with_deadband`).
pub async fn subscribe(&self) -> CaResult<MonitorHandle> {
    let no_deadband = 0.0;
    self.subscribe_with_deadband(no_deadband).await
}
/// Subscribes with `deadband` and the event mask `DBE_VALUE | DBE_LOG | DBE_ALARM`.
pub async fn subscribe_with_deadband(&self, deadband: f64) -> CaResult<MonitorHandle> {
    let default_mask = DBE_VALUE | DBE_LOG | DBE_ALARM;
    self.subscribe_with_mask(deadband, default_mask).await
}
/// Subscribes with an explicit event `mask`, keeping enums in their native form.
pub async fn subscribe_with_mask(&self, deadband: f64, mask: u16) -> CaResult<MonitorHandle> {
    let enum_labels = false;
    self.subscribe_with_mask_enum_as_string(deadband, mask, enum_labels)
        .await
}
/// Subscribes with native enum readback, no float-to-string conversion, and no
/// explicit request count.
///
/// NOTE(review): this issues the same request as `subscribe_with_mask` (both end up
/// calling `subscribe_with_mask_readback_count(.., EnumReadback::Native, false, None)`).
/// Confirm whether "autosize" was meant to pass a different `req_count`.
pub async fn subscribe_with_mask_autosize(
    &self,
    deadband: f64,
    mask: u16,
) -> CaResult<MonitorHandle> {
    let req_count: Option<u32> = None;
    let float_as_string = false;
    self.subscribe_with_mask_readback_count(
        deadband,
        mask,
        EnumReadback::Native,
        float_as_string,
        req_count,
    )
    .await
}
/// Subscribes with an explicit `mask`, optionally reading enums back as their
/// string labels. Floats are left in their native form.
pub async fn subscribe_with_mask_enum_as_string(
    &self,
    deadband: f64,
    mask: u16,
    enum_as_string: bool,
) -> CaResult<MonitorHandle> {
    let float_as_string = false;
    self.subscribe_with_mask_readback(deadband, mask, enum_as_string, float_as_string)
        .await
}
/// Subscribes with an explicit `mask`.
///
/// - `enum_as_string` selects `EnumReadback::Label` instead of `EnumReadback::Native`.
/// - `float_as_string` requests string readback for float/double fields.
/// - No explicit request count is passed.
pub async fn subscribe_with_mask_readback(
    &self,
    deadband: f64,
    mask: u16,
    enum_as_string: bool,
    float_as_string: bool,
) -> CaResult<MonitorHandle> {
    let enum_readback = match enum_as_string {
        true => EnumReadback::Label,
        false => EnumReadback::Native,
    };
    self.subscribe_with_mask_readback_count(deadband, mask, enum_readback, float_as_string, None)
        .await
}
/// Registers a monitor subscription with the coordinator and returns its handle.
///
/// The per-subscription queue depth comes from `EPICS_CA_MONITOR_QUEUE` via
/// `resolve_monitor_queue_size`.
///
/// # Errors
/// - `CaError::Shutdown` when the coordinator does not answer.
/// - Otherwise the coordinator's reply, e.g. `CaError::Disconnected` for an unknown
///   channel.
pub async fn subscribe_with_mask_readback_count(
    &self,
    deadband: f64,
    mask: u16,
    enum_readback: EnumReadback,
    float_as_string: bool,
    req_count: Option<u32>,
) -> CaResult<MonitorHandle> {
    let env_queue = epics_base_rs::runtime::env::get("EPICS_CA_MONITOR_QUEUE")
        .and_then(|raw| raw.parse::<usize>().ok());
    let (callback_tx, callback_rx) = mpsc::channel(resolve_monitor_queue_size(env_queue));
    let coalesce_slot = subscription::CoalesceSlot::new();
    let (reply_tx, reply_rx) = oneshot::channel();

    let request = CoordRequest::Subscribe {
        cid: self.cid,
        mask,
        deadband,
        callback_tx,
        coalesce_slot: coalesce_slot.clone(),
        req_count,
        enum_readback,
        float_as_string,
        reply: reply_tx,
    };
    // If the coordinator is gone, the rejected request (and `reply_tx` with it) is
    // dropped here. The await below then reports `CaError::Shutdown`.
    let _ = self.coord_tx.send(request);

    let subid = match reply_rx.await {
        Ok(outcome) => outcome?,
        Err(_) => return Err(CaError::Shutdown),
    };
    Ok(MonitorHandle {
        subid,
        callback_rx,
        coalesce_slot,
        coord_tx: self.coord_tx.clone(),
        channel: self.clone(),
    })
}
/// Returns a new receiver for the `ConnectionEvent`s broadcast on this channel's
/// `conn_tx`.
pub fn connection_events(&self) -> broadcast::Receiver<ConnectionEvent> {
    let sender = &self.conn_tx;
    sender.subscribe()
}
pub fn on_access_rights_change<F>(&self, mut cb: F) -> EventWatcher
where
F: FnMut(AccessRights) + Send + 'static,
{
let mut rx = self.conn_tx.subscribe();
let handle = epics_base_rs::runtime::task::spawn(async move {
while let Ok(evt) = rx.recv().await {
if let ConnectionEvent::AccessRightsChanged { read, write } = evt {
cb(AccessRights { read, write });
}
}
});
EventWatcher { handle }
}
pub fn on_connection_change<F>(&self, mut cb: F) -> EventWatcher
where
F: FnMut(bool) + Send + 'static,
{
let mut rx = self.conn_tx.subscribe();
let handle = epics_base_rs::runtime::task::spawn(async move {
while let Ok(evt) = rx.recv().await {
match evt {
ConnectionEvent::Connected => cb(true),
ConnectionEvent::Disconnected => cb(false),
_ => {}
}
}
});
EventWatcher { handle }
}
/// Returns the `server_addr` from `info()` rendered with `to_string()`
/// (presumably `ip:port`).
pub async fn host_name(&self) -> CaResult<String> {
    self.info()
        .await
        .map(|info| info.server_addr.to_string())
}
/// Asks the coordinator for the CA minor protocol version of this channel's server
/// circuit.
///
/// Returns `None` when the channel is not operational, the version is unknown, or
/// the coordinator is unavailable.
pub async fn host_minor_protocol(&self) -> Option<u16> {
    let (tx, rx) = oneshot::channel();
    let request = CoordRequest::GetHostMinorProtocol {
        cid: self.cid,
        reply: tx,
    };
    let _ = self.coord_tx.send(request);
    match rx.await {
        Ok(minor) => minor,
        Err(_) => None,
    }
}
/// Returns `true` when the server's minor protocol version is known and at least 2.
pub async fn v42_ok(&self) -> bool {
    match self.host_minor_protocol().await {
        Some(minor) => minor >= 2,
        None => false,
    }
}
/// Stores `data` as this channel's user data, replacing any previous value.
pub fn set_user_data<T: std::any::Any + Send + Sync>(&self, data: Arc<T>) {
    let mut stored = self.user_data.lock();
    *stored = Some(data);
}
pub fn user_data<T: std::any::Any + Send + Sync>(&self) -> Option<Arc<T>> {
self.user_data
.lock()
.as_ref()
.and_then(|any| Arc::clone(any).downcast::<T>().ok())
}
/// Removes and returns the stored user data, leaving `None` behind.
pub fn clear_user_data(&self) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
    std::mem::take(&mut *self.user_data.lock())
}
/// Returns the time elapsed since data was last received on this channel's server
/// circuit.
///
/// Returns `Duration::ZERO` when the channel is not operational, no receive time is
/// recorded, or the coordinator is unavailable.
pub async fn receive_watchdog_delay(&self) -> Duration {
    let (tx, rx) = oneshot::channel();
    let request = CoordRequest::GetWatchdogDelay {
        cid: self.cid,
        reply: tx,
    };
    let _ = self.coord_tx.send(request);
    rx.await.ok().flatten().unwrap_or(Duration::ZERO)
}
}
/// Returns the library version string, formatted as
/// `epics-ca-rs <crate version> (CA v4.<CA_MINOR_VERSION>)`.
pub fn ca_version() -> String {
    let crate_version = env!("CARGO_PKG_VERSION");
    let minor = crate::protocol::CA_MINOR_VERSION;
    format!("epics-ca-rs {} (CA v4.{})", crate_version, minor)
}
/// Receiving end of one monitor subscription, returned by the `subscribe*` methods.
///
/// Dropping the handle sends `CoordRequest::Unsubscribe` for `subid` to the coordinator.
pub struct MonitorHandle {
// Subscription id allocated by the coordinator's `SubscriptionRegistry`.
subid: u32,
// Bounded queue of updates delivered by the coordinator.
callback_rx: mpsc::Receiver<CaResult<Snapshot>>,
// Slot holding coalesced updates. Presumably used when the queue cannot take more;
// TODO confirm in `subscription::CoalesceSlot`.
coalesce_slot: std::sync::Arc<subscription::CoalesceSlot>,
// Coordinator request channel, used for `MonitorConsumed` and `Unsubscribe`.
coord_tx: mpsc::UnboundedSender<CoordRequest>,
// Channel this subscription was created on.
channel: CaChannel,
}
impl MonitorHandle {
/// Returns a clone of the channel this subscription was created on.
pub fn channel(&self) -> CaChannel {
self.channel.clone()
}
/// Returns the subscription id allocated by the coordinator.
pub fn subid(&self) -> u32 {
self.subid
}
/// Waits for the next monitor update.
///
/// On each pass:
/// 1. An update already queued in `callback_rx` is returned first. It is reported
///    with `CoordRequest::MonitorConsumed`, which the coordinator uses to decrement
///    the circuit's outstanding-delivery count for flow control.
/// 2. Otherwise an update parked in the coalesce slot (`take_deliverable`) is
///    returned. It is not reported as consumed.
/// 3. Otherwise the call waits until a new queued update arrives or the coalesce
///    slot is notified, then re-checks.
///
/// Returns `None` once the queue's sender side is closed.
/// NOTE(review): a closed queue returns `None` before the coalesce slot is checked,
/// so a value still parked there is not delivered. Confirm this is intended.
pub async fn recv(&mut self) -> Option<CaResult<Snapshot>> {
use tokio::sync::mpsc::error::TryRecvError;
loop {
// Fast path: drain the queue without blocking.
match self.callback_rx.try_recv() {
Ok(msg) => {
let _ = self
.coord_tx
.send(CoordRequest::MonitorConsumed { subid: self.subid });
return Some(msg);
}
Err(TryRecvError::Disconnected) => return None,
Err(TryRecvError::Empty) => {}
}
if let Some(msg) = self.coalesce_slot.take_deliverable() {
return Some(msg);
}
// Nothing ready: wait for whichever arrives first. After a slot notification,
// loop back and re-check both sources.
let notified = self.coalesce_slot.notified();
tokio::pin!(notified);
tokio::select! {
msg = self.callback_rx.recv() => {
if msg.is_some() {
let _ = self
.coord_tx
.send(CoordRequest::MonitorConsumed { subid: self.subid });
}
return msg;
}
_ = &mut notified => {
}
}
}
}
/// Sets the coalesce slot's paused flag. What pausing does is implemented in
/// `subscription::CoalesceSlot` (not visible here).
pub fn pause(&self) {
self.coalesce_slot.set_paused(true);
}
/// Clears the coalesce slot's paused flag.
pub fn resume(&self) {
self.coalesce_slot.set_paused(false);
}
/// Returns the coalesce slot's paused flag.
pub fn is_paused(&self) -> bool {
self.coalesce_slot.is_paused()
}
}
impl Drop for MonitorHandle {
    /// Asks the coordinator to cancel this subscription. A failed send (coordinator
    /// already gone) is ignored.
    fn drop(&mut self) {
        let request = CoordRequest::Unsubscribe { subid: self.subid };
        let _ = self.coord_tx.send(request);
    }
}
/// Handle to a spawned task that forwards channel events to a user callback.
///
/// The task is aborted when the watcher is dropped or `abort` is called.
#[must_use = "dropping the EventWatcher immediately stops the watcher task; \
              bind it to a variable to keep watching"]
pub struct EventWatcher {
    // The spawned watcher task.
    handle: tokio::task::JoinHandle<()>,
}

impl EventWatcher {
    /// Stops the watcher task now. This is the same as dropping the watcher; the
    /// `Drop` impl performs the abort.
    pub fn abort(self) {
        drop(self);
    }
}

impl Drop for EventWatcher {
    fn drop(&mut self) {
        let task = &self.handle;
        task.abort();
    }
}
/// Outstanding-delivery count at which a circuit is sent `EventsOff`. It is also the
/// minimum monitor queue depth.
const FLOW_CONTROL_OFF_THRESHOLD: usize = 10;
/// Chooses the per-subscription monitor queue depth.
///
/// Uses `env` (the parsed `EPICS_CA_MONITOR_QUEUE`) when present and 256 otherwise,
/// but never less than `FLOW_CONTROL_OFF_THRESHOLD`.
fn resolve_monitor_queue_size(env: Option<usize>) -> usize {
    const DEFAULT_QUEUE: usize = 256;
    let requested = env.unwrap_or(DEFAULT_QUEUE);
    std::cmp::max(requested, FLOW_CONTROL_OFF_THRESHOLD)
}
/// Outstanding-delivery count at or below which a paused circuit is sent `EventsOn`.
const FLOW_CONTROL_ON_THRESHOLD: usize = 5;
/// Per-circuit flow-control bookkeeping kept by the coordinator.
#[derive(Default)]
struct FlowControlState {
    /// Monitor updates queued to consumers and not yet reported as consumed.
    outstanding: usize,
    /// `true` after `EventsOff` was sent and until `EventsOn` is sent.
    active: bool,
}
/// Records one more queued monitor delivery on `circuit`.
///
/// When the outstanding count reaches `FLOW_CONTROL_OFF_THRESHOLD`, sends
/// `EventsOff` once and marks the circuit as paused.
fn flow_control_note_queued(
    flow_control: &mut HashMap<types::CircuitKey, FlowControlState>,
    circuit: types::CircuitKey,
    transport_tx: &mpsc::UnboundedSender<TransportCommand>,
) {
    let (server_addr, priority) = circuit;
    let entry = flow_control.entry(circuit).or_default();
    entry.outstanding = entry.outstanding.saturating_add(1);
    let should_pause = !entry.active && entry.outstanding >= FLOW_CONTROL_OFF_THRESHOLD;
    if !should_pause {
        return;
    }
    let _ = transport_tx.send(TransportCommand::EventsOff {
        server_addr,
        priority,
    });
    entry.active = true;
}
/// Records `count` consumed (or discarded) monitor deliveries on `circuit`.
///
/// - Sends `EventsOn` when a paused circuit drops to `FLOW_CONTROL_ON_THRESHOLD` or
///   below.
/// - Forgets the circuit once it is unpaused with nothing outstanding.
/// - Unknown circuits and `count == 0` are no-ops.
fn flow_control_note_consumed(
    flow_control: &mut HashMap<types::CircuitKey, FlowControlState>,
    circuit: types::CircuitKey,
    count: usize,
    transport_tx: &mpsc::UnboundedSender<TransportCommand>,
) {
    if count == 0 {
        return;
    }
    let (server_addr, priority) = circuit;
    let Some(entry) = flow_control.get_mut(&circuit) else {
        return;
    };
    entry.outstanding = entry.outstanding.saturating_sub(count);
    if entry.active && entry.outstanding <= FLOW_CONTROL_ON_THRESHOLD {
        let _ = transport_tx.send(TransportCommand::EventsOn {
            server_addr,
            priority,
        });
        entry.active = false;
    }
    let idle = !entry.active && entry.outstanding == 0;
    if idle {
        flow_control.remove(&circuit);
    }
}
#[allow(clippy::too_many_arguments)]
async fn run_coordinator(
mut coord_rx: mpsc::UnboundedReceiver<CoordRequest>,
mut search_rx: mpsc::UnboundedReceiver<SearchResponse>,
mut transport_rx: mpsc::UnboundedReceiver<TransportEvent>,
search_tx: mpsc::UnboundedSender<SearchRequest>,
transport_tx: mpsc::UnboundedSender<TransportCommand>,
in_flight: types::InFlightOps,
cid_alloc: types::CidAllocator,
snapshots: ChannelSnapshots,
last_rx_at: ServerLastRxAt,
diag: Arc<CaDiagnostics>,
exception_slot: types::CaExceptionSlot,
search_attempts: types::SearchAttempts,
beacon_ctrl_tx: mpsc::UnboundedSender<beacon_monitor::BeaconControl>,
) {
let mut channels: HashMap<u32, ChannelInner> = HashMap::new();
let mut pending_wait_connected: HashMap<u32, Vec<oneshot::Sender<()>>> = HashMap::new();
let mut pending_found: HashMap<u32, SocketAddr> = HashMap::new();
let mut subscriptions = SubscriptionRegistry::new();
let mut server_channels: HashMap<types::CircuitKey, HashSet<u32>> = HashMap::new();
let mut flow_control: HashMap<types::CircuitKey, FlowControlState> = HashMap::new();
let mut server_minor_version: HashMap<types::CircuitKey, u16> = HashMap::new();
loop {
tokio::select! {
req = coord_rx.recv() => {
let Some(req) = req else { return };
match req {
CoordRequest::RegisterChannel { cid, pv_name, priority, conn_tx } => {
let early_waiters = pending_wait_connected
.remove(&cid)
.unwrap_or_default();
channels.insert(cid, ChannelInner {
cid,
pv_name: pv_name.clone(),
priority,
state: ChannelState::Searching,
sid: 0,
native_type: None,
element_count: 0,
server_addr: None,
access_rights: AccessRights::from_u32(0),
connect_waiters: early_waiters,
conn_tx,
reconnect_count: 0,
last_connected_at: None,
});
if let Some(server_addr) = pending_found.remove(&cid) {
let ch = channels.get_mut(&cid).unwrap();
ch.state = ChannelState::Connecting;
ch.server_addr = Some(server_addr);
server_channels.entry((server_addr, priority)).or_default().insert(cid);
let _ = transport_tx.send(TransportCommand::CreateChannel {
cid,
pv_name,
server_addr,
priority,
});
}
}
CoordRequest::WaitConnected { cid, reply } => {
if let Some(ch) = channels.get_mut(&cid) {
if ch.state == ChannelState::Connected {
let _ = reply.send(());
} else {
ch.connect_waiters.push(reply);
}
} else {
pending_wait_connected
.entry(cid)
.or_default()
.push(reply);
}
}
CoordRequest::Subscribe { cid, mask, deadband, callback_tx, coalesce_slot, req_count, enum_readback, float_as_string, reply } => {
if let Some(ch) = channels.get(&cid) {
let server_addr = ch.server_addr.unwrap_or_else(|| {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
});
let priority = ch.priority;
let connected = ch.state == ChannelState::Connected;
let data_type = ch
.native_type
.map(|t| subscription_readback_dbr(t, enum_readback, float_as_string));
let count = ch
.native_type
.map(|_| subscription::resolve_subscription_count(req_count, ch.element_count));
let subid = subscriptions.alloc_subid();
subscriptions.add(subscription::SubscriptionRecord {
subid,
cid,
data_type,
count,
req_count,
type_user_supplied: false,
enum_readback,
float_as_string,
mask,
server_addr,
priority,
deadband,
callback_tx,
coalesce_slot,
needs_restore: !connected,
last_value: None,
pending_deliveries: 0,
nreplace: 0,
});
if connected {
let _ = transport_tx.send(TransportCommand::Subscribe {
sid: ch.sid,
data_type: data_type.expect("connected channel has native type"),
count: count.expect("connected channel has element count"),
subid,
mask,
server_addr,
priority,
});
}
let _ = reply.send(Ok(subid));
} else {
let _ = reply.send(Err(CaError::Disconnected));
}
}
CoordRequest::Unsubscribe { subid } => {
if let Some(rec) = subscriptions.get(subid) {
let cid = rec.cid;
if let Some(ch) = channels.get(&cid) {
if ch.state == ChannelState::Connected {
if let Some(data_type) = rec.data_type {
let _ = transport_tx.send(TransportCommand::Unsubscribe {
sid: ch.sid,
subid,
data_type,
count: rec.count.unwrap_or(0),
server_addr: ch.server_addr.unwrap(),
priority: ch.priority,
});
}
}
}
}
if let Some(rec) = subscriptions.remove(subid) {
flow_control_note_consumed(
&mut flow_control,
(rec.server_addr, rec.priority),
rec.pending_deliveries,
&transport_tx,
);
}
}
CoordRequest::MonitorConsumed { subid } => {
if let Some(circuit) = subscriptions.mark_consumed(subid) {
flow_control_note_consumed(
&mut flow_control,
circuit,
1,
&transport_tx,
);
}
}
CoordRequest::DropChannel { cid } => {
let sub_ids = subscriptions.for_cid(cid);
for subid in sub_ids {
if let Some(rec) = subscriptions.get(subid) {
if let Some(ch) = channels.get(&cid) {
if ch.state == ChannelState::Connected {
if let Some(data_type) = rec.data_type {
let _ = transport_tx.send(TransportCommand::Unsubscribe {
sid: ch.sid,
subid,
data_type,
count: rec.count.unwrap_or(0),
server_addr: ch.server_addr.unwrap(),
priority: ch.priority,
});
}
}
}
}
if let Some(rec) = subscriptions.remove(subid) {
flow_control_note_consumed(
&mut flow_control,
(rec.server_addr, rec.priority),
rec.pending_deliveries,
&transport_tx,
);
}
}
if let Some(ch) = channels.get(&cid) {
if ch.state.is_operational() {
let _ = transport_tx.send(TransportCommand::ClearChannel {
cid,
sid: ch.sid,
server_addr: ch.server_addr.unwrap(),
priority: ch.priority,
});
}
match ch.state {
ChannelState::Searching
| ChannelState::Connecting
| ChannelState::Disconnected => {
let _ = search_tx.send(SearchRequest::Cancel { cid });
}
_ => {}
}
if let Some(addr) = ch.server_addr {
remove_server_channel(&mut server_channels, (addr, ch.priority), cid);
}
}
channels.remove(&cid);
cid_alloc.release(cid);
snapshots.remove(&cid);
let mut affected = HashSet::with_capacity(1);
affected.insert(cid);
drain_waiters_for_cids(&affected, &in_flight);
}
CoordRequest::GetWatchdogDelay { cid, reply } => {
let delay = channels.get(&cid).and_then(|ch| {
if !ch.state.is_operational() {
return None;
}
let addr = ch.server_addr?;
last_rx_at.get(&(addr, ch.priority)).map(|e| e.value().elapsed())
});
let _ = reply.send(delay);
}
CoordRequest::GetHostMinorProtocol { cid, reply } => {
let v = channels.get(&cid).and_then(|ch| {
if !ch.state.is_operational() {
return None;
}
let addr = ch.server_addr?;
server_minor_version.get(&(addr, ch.priority)).copied()
});
let _ = reply.send(v);
}
CoordRequest::GetIocConnectionCount { reply } => {
let states = channels
.values()
.map(|ch| (ch.state, ch.server_addr, ch.priority));
let _ = reply.send(operational_circuit_count(states));
}
CoordRequest::Shutdown { reply } => {
for ch in channels.values() {
if ch.state.is_operational() {
if let Some(addr) = ch.server_addr {
let _ = transport_tx.send(TransportCommand::ClearChannel {
cid: ch.cid,
sid: ch.sid,
server_addr: addr,
priority: ch.priority,
});
}
}
}
let _ = reply.send(());
return; }
CoordRequest::ForceRescanServer { server_addr, kind } => {
let is_real_restart = matches!(
kind,
beacon_monitor::BeaconAnomalyKind::IdMismatch
| beacon_monitor::BeaconAnomalyKind::PeriodCollapse
);
if is_real_restart {
diag.beacon_anomalies.fetch_add(1, Ordering::Relaxed);
diag.record(DiagEvent::BeaconAnomaly { server: server_addr });
tracing::warn!(
server = %server_addr,
?kind,
"beacon anomaly detected — IOC may have restarted"
);
metrics::counter!(
"ca_client_beacon_anomalies_total",
"server" => server_addr.to_string()
)
.increment(1);
} else {
tracing::debug!(
server = %server_addr,
"first sighting of beacon source — waking pending searches"
);
metrics::counter!(
"ca_client_beacon_first_sighting_total",
"server" => server_addr.to_string()
)
.increment(1);
}
for ch in channels.values() {
if ch.state == ChannelState::Disconnected
|| ch.state == ChannelState::Searching
{
let _ = search_tx.send(SearchRequest::Schedule {
cid: ch.cid,
pv_name: ch.pv_name.to_string(),
reason: SearchReason::BeaconAnomaly,
});
}
}
}
CoordRequest::BeaconArrival { server_addr, anomaly } => {
let states = channels.values().map(|ch| (ch.state, ch.server_addr));
for target in beacon_arrival_targets(states, server_addr) {
let _ = transport_tx.send(
TransportCommand::BeaconArrivalNotify {
server_addr: target,
anomaly,
},
);
}
}
}
}
resp = search_rx.recv() => {
let Some(resp) = resp else { return };
match resp {
SearchResponse::Found { cid, server_addr } => {
if let Some(ch) = channels.get_mut(&cid) {
if ch.state == ChannelState::Searching || ch.state == ChannelState::Disconnected {
let priority = ch.priority;
if let Some(old_addr) = ch.server_addr {
remove_server_channel(&mut server_channels, (old_addr, priority), cid);
}
ch.state = ChannelState::Connecting;
ch.server_addr = Some(server_addr);
server_channels.entry((server_addr, priority)).or_default().insert(cid);
let _ = transport_tx.send(TransportCommand::CreateChannel {
cid,
pv_name: ch.pv_name.to_string(),
server_addr,
priority,
});
}
} else {
pending_found.insert(cid, server_addr);
}
}
SearchResponse::MultiplyDefined {
pv_name,
prev_addr,
new_addr,
} => {
types::dispatch_exception(
&exception_slot,
types::CaException {
kind: types::CaExceptionKind::ServerError,
message: format!(
"Channel: \"{}\", Connecting to: {}, Ignored: {}",
pv_name, prev_addr, new_addr,
),
server_addr: Some(new_addr),
pv_name: Some(pv_name),
status: Some(crate::protocol::ECA_DBLCHNL),
},
);
}
}
}
evt = transport_rx.recv() => {
let Some(evt) = evt else { return };
match evt {
TransportEvent::ChannelCreated { cid, sid, data_type, element_count, access, server_addr, priority } => {
if let Some(ch) = channels.get_mut(&cid) {
let was_disconnected = matches!(ch.state, ChannelState::Disconnected);
let dbr_type = DbFieldType::from_u16(data_type).ok();
let previous_native = ch.native_type;
let native_changed = match (previous_native, dbr_type) {
(Some(p), Some(c)) => p != c,
(Some(_), None) | (None, Some(_)) => true,
(None, None) => false,
};
ch.state = ChannelState::Connected;
ch.sid = sid;
ch.native_type = dbr_type;
ch.element_count = element_count;
ch.server_addr = Some(server_addr);
ch.access_rights = access;
ch.last_connected_at = Some(std::time::Instant::now());
if let Some(dbr) = dbr_type {
snapshots.insert(
cid,
types::ChannelSnapshotPublic {
sid,
native_type: dbr,
element_count,
server_addr,
access_rights: access,
state: ChannelState::Connected,
},
);
} else {
snapshots.remove(&cid);
}
if was_disconnected {
tracing::info!(pv = %ch.pv_name, cid, sid, server = %server_addr, "channel reconnected");
} else {
tracing::info!(pv = %ch.pv_name, cid, sid, server = %server_addr, "channel connected");
}
metrics::counter!("ca_client_connections_total", "server" => server_addr.to_string()).increment(1);
metrics::gauge!("ca_client_channels_connected").increment(1.0);
search_attempts.remove(&cid);
for waiter in ch.connect_waiters.drain(..) {
let _ = waiter.send(());
}
let _ = ch.conn_tx.send(ConnectionEvent::Connected);
let _ = ch.conn_tx.send(ConnectionEvent::AccessRightsChanged {
read: access.read,
write: access.write,
});
if native_changed {
if let Some(cur) = dbr_type {
let _ = ch.conn_tx.send(ConnectionEvent::NativeTypeChanged {
previous: previous_native,
current: cur,
});
tracing::info!(
pv = %ch.pv_name,
previous = ?previous_native,
current = ?cur,
"channel native DBR type changed"
);
}
}
let (restored, stale) = subscriptions.restore_for_channel(
cid,
sid,
data_type,
element_count,
native_changed,
server_addr,
&transport_tx,
);
diag.connections.fetch_add(1, Ordering::Relaxed);
diag.record(DiagEvent::Connected { pv: ch.pv_name.to_string(), server: server_addr });
if restored > 0 || stale > 0 {
diag.reconnections.fetch_add(1, Ordering::Relaxed);
diag.subscriptions_restored.fetch_add(restored as u64, Ordering::Relaxed);
diag.subscriptions_stale.fetch_add(stale as u64, Ordering::Relaxed);
diag.record(DiagEvent::Reconnected { pv: ch.pv_name.to_string(), restored, stale });
tracing::debug!(
pv = %ch.pv_name,
restored,
stale,
"CA reconnect: subscriptions restored"
);
}
let _ = search_tx.send(SearchRequest::ConnectResult {
cid,
success: true,
server_addr,
});
} else {
tracing::debug!(
cid,
sid,
server = %server_addr,
"late CREATE_CHAN response for unknown CID; sending CLEAR_CHANNEL (libca parity)"
);
let _ = transport_tx.send(TransportCommand::ClearChannel {
cid,
sid,
server_addr,
priority,
});
}
}
TransportEvent::MonitorData { subid, data_type, count, data } => {
use subscription::MonitorDeliveryOutcome;
match subscriptions.on_monitor_data(subid, data_type, count, &data) {
MonitorDeliveryOutcome::Queued(circuit) => {
flow_control_note_queued(
&mut flow_control,
circuit,
&transport_tx,
);
}
MonitorDeliveryOutcome::Slotted(_circuit) => {
metrics::counter!(
"ca_client_coalesced_monitors_total"
)
.increment(1);
}
MonitorDeliveryOutcome::Dropped(_server_addr) => {
diag.dropped_monitors.fetch_add(1, Ordering::Relaxed);
tracing::warn!(
subid,
"monitor dropped (consumer handle closed)"
);
metrics::counter!("ca_client_dropped_monitors_total").increment(1);
}
MonitorDeliveryOutcome::Filtered
| MonitorDeliveryOutcome::NotFound => {}
}
}
TransportEvent::MonitorStatusError { subid, eca_status } => {
use subscription::MonitorDeliveryOutcome;
match subscriptions.on_monitor_error(subid, eca_status) {
MonitorDeliveryOutcome::Queued(circuit) => {
flow_control_note_queued(
&mut flow_control,
circuit,
&transport_tx,
);
}
MonitorDeliveryOutcome::Slotted(_) => {
metrics::counter!(
"ca_client_coalesced_monitors_total"
)
.increment(1);
}
MonitorDeliveryOutcome::Dropped(_) => {
diag.dropped_monitors.fetch_add(1, Ordering::Relaxed);
metrics::counter!(
"ca_client_monitor_error_drops_total"
)
.increment(1);
}
MonitorDeliveryOutcome::Filtered
| MonitorDeliveryOutcome::NotFound => {}
}
}
TransportEvent::AccessRightsChanged { cid, access } => {
if let Some(ch) = channels.get_mut(&cid) {
ch.access_rights = access;
let _ = ch.conn_tx.send(ConnectionEvent::AccessRightsChanged {
read: access.read,
write: access.write,
});
if let Some(mut snap) = snapshots.get_mut(&cid) {
snap.access_rights = access;
}
}
}
TransportEvent::ChannelCreateFailed { cid } => {
if let Some(ch) = channels.get_mut(&cid) {
let server_addr = ch.server_addr;
ch.state = ChannelState::Disconnected;
snapshots.remove(&cid);
let _ = ch.conn_tx.send(ConnectionEvent::Disconnected);
let _ = search_tx.send(SearchRequest::Schedule {
cid,
pv_name: ch.pv_name.to_string(),
reason: SearchReason::Reconnect,
});
if let Some(addr) = server_addr {
let _ = search_tx.send(SearchRequest::ConnectResult {
cid,
success: false,
server_addr: addr,
});
}
}
}
TransportEvent::ServerError {
eca_status,
original_request,
message,
server_addr,
} => {
let annotated = match original_request {
Some(cmd) => {
if message.is_empty() {
format!("(while processing cmd={cmd})")
} else {
format!("{message} (while processing cmd={cmd})")
}
}
None => message,
};
types::dispatch_exception(
&exception_slot,
types::CaException {
kind: types::CaExceptionKind::ServerError,
message: annotated,
server_addr: Some(server_addr),
pv_name: None,
status: Some(eca_status),
},
);
}
TransportEvent::TcpClosed { server_addr, priority } => {
let circuit = (server_addr, priority);
let n_affected = server_channels
.get(&circuit)
.map(|s| s.len())
.unwrap_or(0);
tracing::warn!(server = %server_addr, priority, channels = n_affected, "TCP circuit closed");
metrics::counter!("ca_client_tcp_closed_total", "server" => server_addr.to_string()).increment(1);
flow_control.remove(&circuit);
last_rx_at.remove(&circuit);
server_minor_version.remove(&circuit);
handle_disconnect(&mut channels, &mut subscriptions, &mut server_channels, &search_tx, circuit, &diag, &in_flight, &snapshots);
}
TransportEvent::ServerDisconnect { cid, server_addr } => {
if let Some(ch) = channels.get_mut(&cid) {
if ch.server_addr == Some(server_addr) {
ch.state = ChannelState::Disconnected;
snapshots.remove(&cid);
let _ = ch.conn_tx.send(ConnectionEvent::Disconnected);
let pv_name = ch.pv_name.to_string();
let cids = vec![cid];
let cleared = subscriptions.mark_disconnected(&cids);
for (circuit, count) in cleared {
flow_control_note_consumed(
&mut flow_control,
circuit,
count,
&transport_tx,
);
}
let mut affected = HashSet::with_capacity(1);
affected.insert(cid);
drain_waiters_for_cids(&affected, &in_flight);
let _ = search_tx.send(SearchRequest::Schedule {
cid,
pv_name: pv_name.clone(),
reason: SearchReason::Reconnect,
});
types::dispatch_exception(
&exception_slot,
types::CaException {
kind: types::CaExceptionKind::ServerDisconnect,
message: "server-initiated channel close".to_string(),
server_addr: Some(server_addr),
pv_name: Some(pv_name),
status: None,
},
);
}
}
}
TransportEvent::CircuitUnresponsive { server_addr, priority } => {
diag.unresponsive_events.fetch_add(1, Ordering::Relaxed);
diag.record(DiagEvent::Unresponsive { server: server_addr });
tracing::warn!(server = %server_addr, priority, "circuit unresponsive (echo timeout)");
metrics::counter!("ca_client_unresponsive_total", "server" => server_addr.to_string()).increment(1);
types::dispatch_exception(
&exception_slot,
types::CaException {
kind: types::CaExceptionKind::ServerError,
message: format!(
"circuit unresponsive: {server_addr} (matches libca ECA_UNRESPTMO)"
),
server_addr: Some(server_addr),
pv_name: None,
status: Some(crate::protocol::ECA_UNRESPTMO),
},
);
let mut affected_cids: Vec<u32> = Vec::new();
for ch in channels.values_mut() {
if ch.server_addr == Some(server_addr)
&& ch.priority == priority
&& ch.state == ChannelState::Connected
{
ch.state = ChannelState::Unresponsive;
if let Some(mut snap) = snapshots.get_mut(&ch.cid) {
snap.state = ChannelState::Unresponsive;
snap.access_rights = AccessRights { read: false, write: false };
}
ch.access_rights = AccessRights { read: false, write: false };
let _ = ch.conn_tx.send(ConnectionEvent::Unresponsive);
let _ = ch.conn_tx.send(ConnectionEvent::AccessRightsChanged {
read: false,
write: false,
});
affected_cids.push(ch.cid);
}
}
if !affected_cids.is_empty() {
let cid_set: HashSet<u32> = affected_cids.iter().copied().collect();
drain_waiters_for_cids(&cid_set, &in_flight);
let cleared = subscriptions.mark_disconnected(&affected_cids);
for (circuit, count) in cleared {
flow_control_note_consumed(
&mut flow_control,
circuit,
count,
&transport_tx,
);
}
}
}
TransportEvent::CircuitResponsive { server_addr, priority } => {
diag.record(DiagEvent::Responsive { server: server_addr });
tracing::info!(server = %server_addr, priority, "circuit responsive again");
let mut recovered_cids: Vec<u32> = Vec::new();
for ch in channels.values_mut() {
if ch.server_addr == Some(server_addr)
&& ch.priority == priority
&& ch.state == ChannelState::Unresponsive
{
ch.state = ChannelState::Connected;
if let Some(mut snap) = snapshots.get_mut(&ch.cid) {
snap.state = ChannelState::Connected;
}
let _ = ch.conn_tx.send(ConnectionEvent::Connected);
recovered_cids.push(ch.cid);
}
}
for cid in recovered_cids {
for sub_id in subscriptions.for_cid(cid) {
if let Some(rec) = subscriptions.get(sub_id) {
if let (Some(data_type), Some(_count)) =
(rec.data_type, rec.count)
{
if let Some(ch) = channels.get(&cid) {
if let Some(addr) = ch.server_addr {
let _ =
transport_tx.send(TransportCommand::ReadNotify {
sid: ch.sid,
data_type,
count: rec.count.unwrap_or(0),
ioid: in_flight.alloc_ioid(),
server_addr: addr,
priority: ch.priority,
});
}
}
}
}
}
}
}
TransportEvent::ServerVersion { server_addr, priority, minor_version } => {
server_minor_version.insert((server_addr, priority), minor_version);
}
TransportEvent::ServerConnected { server_addr } => {
let _ = beacon_ctrl_tx.send(
beacon_monitor::BeaconControl::ResetServer { server_addr },
);
}
}
}
}
}
}
/// Tear down every channel bound to a TCP circuit that has closed.
///
/// Each channel on `circuit` (matching both `server_addr` and `priority`) that
/// is operational or still `Connecting` is moved to `Disconnected`. For each
/// such channel:
/// - its published snapshot is removed;
/// - `ConnectionEvent::Disconnected` is emitted;
/// - its `reconnect_count` is reset if it had been connected for more than
///   30 s, otherwise incremented;
/// - a reconnect search is scheduled.
///
/// Afterwards the circuit's `server_channels` entry is dropped and the
/// affected channels' subscriptions are marked disconnected. Their pending
/// read/write waiters are failed with `CaError::Disconnected`.
///
/// The `ca_client_channels_connected` gauge is decremented only for channels
/// that were operational. The gauge is incremented when a channel becomes
/// connected, so a channel still `Connecting` (assumed not yet connected)
/// never contributed to it. Counting it here would push the gauge below the
/// real number of connected channels.
// Coordinator state is passed piecemeal; bundling it into a struct is a larger
// refactor than this function warrants.
#[allow(clippy::too_many_arguments)]
fn handle_disconnect(
    channels: &mut HashMap<u32, ChannelInner>,
    subscriptions: &mut SubscriptionRegistry,
    server_channels: &mut HashMap<types::CircuitKey, HashSet<u32>>,
    search_tx: &mpsc::UnboundedSender<SearchRequest>,
    circuit: types::CircuitKey,
    diag: &CaDiagnostics,
    in_flight: &types::InFlightOps,
    snapshots: &ChannelSnapshots,
) {
    let (server_addr, priority) = circuit;
    let mut affected_cids = Vec::new();
    // How many affected channels were operational (and therefore counted in
    // the connected-channels gauge) when the circuit closed.
    let mut previously_operational: usize = 0;
    let now = std::time::Instant::now();
    for ch in channels.values_mut() {
        let on_circuit = ch.server_addr == Some(server_addr) && ch.priority == priority;
        if !on_circuit {
            continue;
        }
        let was_operational = ch.state.is_operational();
        if !was_operational && ch.state != ChannelState::Connecting {
            continue;
        }
        if was_operational {
            previously_operational += 1;
        }
        ch.state = ChannelState::Disconnected;
        snapshots.remove(&ch.cid);
        affected_cids.push(ch.cid);
        let _ = ch.conn_tx.send(ConnectionEvent::Disconnected);
        // A channel that stayed connected for more than 30 s starts counting
        // reconnects afresh; a flapping one keeps accumulating (presumably
        // this feeds a reconnect back-off elsewhere — confirm).
        let sustained = ch
            .last_connected_at
            .map(|t| now.duration_since(t).as_secs() > 30)
            .unwrap_or(false);
        if sustained {
            ch.reconnect_count = 0;
        } else {
            ch.reconnect_count = ch.reconnect_count.saturating_add(1);
        }
        let _ = search_tx.send(SearchRequest::Schedule {
            cid: ch.cid,
            pv_name: ch.pv_name.to_string(),
            reason: SearchReason::Reconnect,
        });
    }
    if !affected_cids.is_empty() {
        diag.disconnections.fetch_add(1, Ordering::Relaxed);
        diag.record(DiagEvent::Disconnected {
            server: server_addr,
            channels: affected_cids.len(),
        });
        tracing::warn!(
            server = %server_addr,
            affected = affected_cids.len(),
            "disconnect: scheduling reconnect for affected channels"
        );
        metrics::counter!("ca_client_disconnections_total", "server" => server_addr.to_string())
            .increment(1);
        if previously_operational > 0 {
            metrics::gauge!("ca_client_channels_connected")
                .decrement(previously_operational as f64);
        }
    }
    server_channels.remove(&circuit);
    // The circuit's flow-control state is not visible from here, so the
    // per-circuit counts returned by `mark_disconnected` are not needed.
    let _ = subscriptions.mark_disconnected(&affected_cids);
    let affected: HashSet<u32> = affected_cids.into_iter().collect();
    drain_waiters_for_cids(&affected, in_flight);
}
/// Fail every pending read and write whose channel id is in `cids`.
///
/// Matching entries are removed from `in_flight.reads` and
/// `in_flight.writes`, and their waiters receive `Err(CaError::Disconnected)`.
/// Entries for other channels are left untouched. An entry that has already
/// been removed by the time it is looked up again (for example because its
/// reply arrived first) is skipped.
pub(crate) fn drain_waiters_for_cids(cids: &HashSet<u32>, in_flight: &types::InFlightOps) {
    // Gather the matching ioids first and remove them afterwards, instead of
    // removing entries while iterating the map.
    let doomed_reads = in_flight
        .reads
        .iter()
        .filter_map(|entry| cids.contains(&entry.value().cid()).then(|| *entry.key()))
        .collect::<Vec<u32>>();
    for ioid in doomed_reads {
        let Some((_, waiter)) = in_flight.reads.remove(&ioid) else {
            continue;
        };
        waiter.send(Err(CaError::Disconnected));
    }

    // Write entries are `(cid, reply sender)` pairs keyed by ioid.
    let doomed_writes = in_flight
        .writes
        .iter()
        .filter_map(|entry| cids.contains(&entry.value().0).then(|| *entry.key()))
        .collect::<Vec<u32>>();
    for ioid in doomed_writes {
        let Some((_, (_, reply_tx))) = in_flight.writes.remove(&ioid) else {
            continue;
        };
        let _ = reply_tx.send(Err(CaError::Disconnected));
    }
}
/// Choose which server addresses a received beacon should be reported to.
///
/// Only channels that are operational and have a server address are
/// considered. If `beacon_addr` has a specified IP and exactly equals one of
/// those addresses, the result is just `[beacon_addr]`. Otherwise every
/// distinct operational server address sharing the beacon's port is returned,
/// in unspecified order. This covers beacons whose source address is
/// unspecified or differs from the address the channel connected to (e.g. a
/// multi-homed IOC).
fn beacon_arrival_targets<I>(channel_states: I, beacon_addr: SocketAddr) -> Vec<SocketAddr>
where
    I: IntoIterator<Item = (ChannelState, Option<SocketAddr>)>,
{
    let exact_match_possible = !beacon_addr.ip().is_unspecified();
    let beacon_port = beacon_addr.port();
    let live_servers = channel_states
        .into_iter()
        .filter_map(|(state, addr)| if state.is_operational() { addr } else { None });

    let mut same_port = HashSet::<SocketAddr>::new();
    for server in live_servers {
        if exact_match_possible && server == beacon_addr {
            // An exact hit wins outright; no port-based fan-out needed.
            return vec![beacon_addr];
        }
        if server.port() == beacon_port {
            same_port.insert(server);
        }
    }
    same_port.into_iter().collect()
}
/// Count the distinct `(server address, priority)` circuits that carry at
/// least one operational channel.
///
/// Channels that are not operational, or that have no server address, are
/// ignored.
fn operational_circuit_count<I>(channel_states: I) -> usize
where
    I: IntoIterator<Item = (ChannelState, Option<SocketAddr>, u8)>,
{
    let live_circuits: HashSet<types::CircuitKey> = channel_states
        .into_iter()
        .filter_map(|(state, addr, priority)| {
            if state.is_operational() {
                addr.map(|server| (server, priority))
            } else {
                None
            }
        })
        .collect();
    live_circuits.len()
}
/// Remove `cid` from the channel set registered for `circuit`.
///
/// The circuit's entry is dropped entirely once its set becomes empty. An
/// unknown circuit is a no-op.
fn remove_server_channel(
    server_channels: &mut HashMap<types::CircuitKey, HashSet<u32>>,
    circuit: types::CircuitKey,
    cid: u32,
) {
    let Some(members) = server_channels.get_mut(&circuit) else {
        return;
    };
    members.remove(&cid);
    if members.is_empty() {
        server_channels.remove(&circuit);
    }
}
/// Resolve `host` (an IPv4 literal or a DNS name) plus `port` to one socket
/// address.
///
/// IPv4 literals are converted directly, without consulting the resolver.
/// Anything else is resolved as `"{host}:{port}"` through `ToSocketAddrs`.
/// The first IPv4 result is preferred; otherwise the first result of any
/// family is used.
///
/// # Errors
///
/// Returns `CaError::Protocol` if the lookup fails or yields no addresses.
fn resolve_host(host: &str, port: u16) -> CaResult<SocketAddr> {
    use std::net::ToSocketAddrs;

    if let Ok(literal) = host.parse::<Ipv4Addr>() {
        return Ok(SocketAddr::from(SocketAddrV4::new(literal, port)));
    }

    let lookup_target = format!("{host}:{port}");
    let candidates: Vec<SocketAddr> = match lookup_target.to_socket_addrs() {
        Ok(found) => found.collect(),
        Err(e) => return Err(CaError::Protocol(format!("cannot resolve '{host}': {e}"))),
    };

    match candidates.iter().copied().find(SocketAddr::is_ipv4) {
        Some(v4) => Ok(v4),
        None => candidates
            .first()
            .copied()
            .ok_or_else(|| CaError::Protocol(format!("no addresses for '{host}'"))),
    }
}
/// One destination in the client's address list.
#[derive(Debug, Clone)]
pub(crate) struct AddrEntry {
    /// Current resolved socket address.
    pub sock: SocketAddr,
    /// Original hostname when the entry was given as a name rather than an
    /// IP literal; `refresh_dns` re-resolves it.
    pub hostname: Option<String>,
    /// Port combined with `hostname` on each re-resolution.
    pub port: u16,
}
impl AddrEntry {
    /// Bundle an already-resolved address with its optional source hostname.
    pub fn new(sock: SocketAddr, hostname: Option<String>, port: u16) -> Self {
        AddrEntry {
            sock,
            hostname,
            port,
        }
    }

    /// Re-resolve the stored hostname and update `sock` with the result.
    ///
    /// Entries without a hostname return the current `sock` unchanged. On a
    /// resolution error, the error is returned and `sock` is left as it was.
    pub fn refresh_dns(&mut self) -> CaResult<SocketAddr> {
        match self.hostname.as_deref() {
            None => Ok(self.sock),
            Some(name) => {
                let fresh = resolve_host(name, self.port)?;
                self.sock = fresh;
                Ok(fresh)
            }
        }
    }
}
/// Build the client's search-destination list from the environment.
///
/// `EPICS_CA_ADDR_LIST` is read as whitespace-separated `host` or `host:port`
/// tokens. A missing port defaults to `EPICS_CA_SERVER_PORT` if it parses,
/// else `CA_SERVER_PORT`. Each token is resolved with `resolve_host`; the
/// following tokens are dropped with a debug log:
/// - tokens with an unparseable port;
/// - tokens that fail to resolve;
/// - tokens whose IPv4 result is listed in `EPICS_RS_CLIENT_IGNORE`.
///
/// Non-IPv4-literal hosts keep their hostname so `AddrEntry::refresh_dns` can
/// re-resolve them later.
///
/// Unless `EPICS_CA_AUTO_ADDR_LIST` contains `no` or `NO`, the discovered
/// broadcast addresses are appended via `append_auto_addr_entries`.
///
/// This function currently always returns `Ok`.
pub(crate) fn parse_addr_list_with_hostnames() -> CaResult<Vec<AddrEntry>> {
    let mut addrs: Vec<AddrEntry> = Vec::new();
    let default_port = epics_base_rs::runtime::env::get("EPICS_CA_SERVER_PORT")
        .and_then(|s| s.parse::<u16>().ok())
        .unwrap_or(CA_SERVER_PORT);
    let ignore_ips = epics_rs_client_ignore();
    if let Some(list) = epics_base_rs::runtime::env::get("EPICS_CA_ADDR_LIST") {
        for entry in list.split_whitespace() {
            // `host:port` overrides the default port; a bare host uses it.
            let (host_raw, port) = match entry.rsplit_once(':') {
                Some((host, port_str)) => match port_str.parse::<u16>() {
                    Ok(port) => (host, port),
                    Err(_) => {
                        tracing::debug!(
                            token = %entry,
                            "EPICS_CA_ADDR_LIST: dropped entry with invalid port"
                        );
                        continue;
                    }
                },
                None => (entry, default_port),
            };
            // Only names are worth re-resolving later; IPv4 literals are fixed.
            let hostname = host_raw
                .parse::<Ipv4Addr>()
                .is_err()
                .then(|| host_raw.to_string());
            match resolve_host(host_raw, port) {
                Ok(sock) => {
                    if let SocketAddr::V4(v4) = sock {
                        if ignore_ips.contains(v4.ip()) {
                            tracing::debug!(
                                target: "epics_ca_rs::client",
                                token = %entry,
                                ip = %v4.ip(),
                                "EPICS_RS_CLIENT_IGNORE: dropping ADDR_LIST entry"
                            );
                            continue;
                        }
                    }
                    addrs.push(AddrEntry::new(sock, hostname, port));
                }
                Err(e) => tracing::debug!(token = %entry, error = %e,
                    "EPICS_CA_ADDR_LIST: dropped unresolvable entry"),
            }
        }
    }
    let auto_addr = epics_base_rs::runtime::env::get_or("EPICS_CA_AUTO_ADDR_LIST", "YES");
    // Substring test: any value containing `no` or `NO` disables auto-discovery.
    let auto_addr_enabled = !(auto_addr.contains("no") || auto_addr.contains("NO"));
    if auto_addr_enabled {
        let bcasts = crate::server::addr_list::discover_broadcast_addrs();
        append_auto_addr_entries(&mut addrs, &bcasts, default_port);
    }
    Ok(addrs)
}
/// Append the automatic search destinations to `addrs`, skipping duplicates.
///
/// Order of appended candidates:
/// 1. each broadcast address in `bcasts`;
/// 2. loopback, only when `bcasts` is empty;
/// 3. the limited broadcast address `255.255.255.255`, always.
///
/// Every candidate uses `server_port`. A candidate is added only if no
/// existing entry (including ones added by this call) already has the same
/// socket address.
fn append_auto_addr_entries(addrs: &mut Vec<AddrEntry>, bcasts: &[Ipv4Addr], server_port: u16) {
    let mut push_unique = |ip: Ipv4Addr| {
        let target = SocketAddr::V4(SocketAddrV4::new(ip, server_port));
        if addrs.iter().all(|existing| existing.sock != target) {
            addrs.push(AddrEntry::new(target, None, server_port));
        }
    };
    bcasts.iter().copied().for_each(&mut push_unique);
    if bcasts.is_empty() {
        push_unique(Ipv4Addr::LOCALHOST);
    }
    push_unique(Ipv4Addr::BROADCAST);
}
// Unit tests for the DBR readback-type selection helpers:
// `enum_string_readback_dbr`, `float_as_string_readback_dbr`,
// `enum_cli_readback_dbr` and `subscription_readback_dbr`.
#[cfg(test)]
mod enum_readback_tests {
use super::*;
// ENUM fields switch to DBR_TIME_STRING / DBR_STRING (by `want_time`) when
// enum-as-string is requested.
#[test]
fn enum_string_when_opted_in() {
assert_eq!(
enum_string_readback_dbr(DbFieldType::Enum, true, true),
Some(DBR_TIME_STRING)
);
assert_eq!(
enum_string_readback_dbr(DbFieldType::Enum, false, true),
Some(DBR_STRING)
);
}
// No substitution when enum-as-string is off, or for non-ENUM native types.
#[test]
fn native_when_opted_out_or_not_enum() {
assert_eq!(
enum_string_readback_dbr(DbFieldType::Enum, true, false),
None
);
assert_eq!(
enum_string_readback_dbr(DbFieldType::Enum, false, false),
None
);
assert_eq!(
enum_string_readback_dbr(DbFieldType::Double, true, true),
None
);
assert_eq!(
enum_string_readback_dbr(DbFieldType::Long, false, true),
None
);
}
// Float-as-string substitution applies to FLOAT and DOUBLE only.
#[test]
fn float_as_string_only_for_float_and_double() {
assert_eq!(
float_as_string_readback_dbr(DbFieldType::Float),
Some(DBR_TIME_STRING)
);
assert_eq!(
float_as_string_readback_dbr(DbFieldType::Double),
Some(DBR_TIME_STRING)
);
for t in [
DbFieldType::String,
DbFieldType::Short,
DbFieldType::Enum,
DbFieldType::Char,
DbFieldType::Long,
] {
assert_eq!(float_as_string_readback_dbr(t), None, "{t:?}");
}
}
// ENUM maps to DBR_TIME_INT (numeric) or DBR_TIME_STRING (label); every
// other native type is left alone in both modes.
#[test]
fn enum_cli_readback_substitutes_int_or_string() {
assert_eq!(
enum_cli_readback_dbr(DbFieldType::Enum, true),
Some(DBR_TIME_INT)
);
assert_eq!(
enum_cli_readback_dbr(DbFieldType::Enum, false),
Some(DBR_TIME_STRING)
);
for t in [
DbFieldType::String,
DbFieldType::Short,
DbFieldType::Float,
DbFieldType::Double,
DbFieldType::Char,
DbFieldType::Long,
] {
assert_eq!(enum_cli_readback_dbr(t, true), None, "{t:?} -n");
assert_eq!(enum_cli_readback_dbr(t, false), None, "{t:?} default");
}
}
// Combined selection: the enum readback mode applies to ENUM channels,
// float-as-string to FLOAT/DOUBLE, and otherwise the native type's TIME DBR
// is used.
#[test]
fn subscription_readback_chain_precedence() {
assert_eq!(
subscription_readback_dbr(DbFieldType::Enum, EnumReadback::Label, true),
DBR_TIME_STRING
);
assert_eq!(
subscription_readback_dbr(DbFieldType::Enum, EnumReadback::Numeric, true),
DBR_TIME_INT
);
assert_eq!(
subscription_readback_dbr(DbFieldType::Enum, EnumReadback::Native, true),
DbFieldType::Enum.time_dbr_type()
);
assert_eq!(
subscription_readback_dbr(DbFieldType::Float, EnumReadback::Native, true),
DBR_TIME_STRING
);
assert_eq!(
subscription_readback_dbr(DbFieldType::Double, EnumReadback::Native, true),
DBR_TIME_STRING
);
assert_eq!(
subscription_readback_dbr(DbFieldType::Double, EnumReadback::Native, false),
DbFieldType::Double.time_dbr_type()
);
assert_eq!(
subscription_readback_dbr(DbFieldType::Long, EnumReadback::Label, true),
DbFieldType::Long.time_dbr_type()
);
assert_eq!(
subscription_readback_dbr(DbFieldType::Long, EnumReadback::Numeric, false),
DbFieldType::Long.time_dbr_type()
);
}
}
// Tests for `epics_rs_client_ignore` parsing of EPICS_RS_CLIENT_IGNORE.
// Each test mutates the process environment, so all of them share the
// `epics_env` serial key (serial_test) to avoid racing each other.
// NOTE(review): if an assertion fails, the trailing `remove_var` is skipped
// and the variable leaks into later tests in the same process.
#[cfg(test)]
mod ignore_servers_tests {
use super::*;
use serial_test::serial;
// Unset variable => empty ignore list.
#[test]
#[serial(epics_env)]
fn empty_when_unset() {
// SAFETY(review): relies on #[serial(epics_env)] to keep env-mutating tests
// apart; tests outside that group that read the environment are not excluded.
unsafe { std::env::remove_var("EPICS_RS_CLIENT_IGNORE") };
assert!(epics_rs_client_ignore().is_empty());
}
// Whitespace-separated IPv4 literals are all returned.
#[test]
#[serial(epics_env)]
fn parses_ip_literals() {
unsafe { std::env::set_var("EPICS_RS_CLIENT_IGNORE", "10.0.0.1 192.168.1.42") };
let v = epics_rs_client_ignore();
assert_eq!(v.len(), 2);
assert!(v.contains(&Ipv4Addr::new(10, 0, 0, 1)));
assert!(v.contains(&Ipv4Addr::new(192, 168, 1, 42)));
unsafe { std::env::remove_var("EPICS_RS_CLIENT_IGNORE") };
}
// A `:port` suffix is ignored; only the IP is kept.
#[test]
#[serial(epics_env)]
fn strips_port_suffix() {
unsafe { std::env::set_var("EPICS_RS_CLIENT_IGNORE", "10.0.0.1:5064") };
let v = epics_rs_client_ignore();
assert_eq!(v.len(), 1);
assert_eq!(v[0], Ipv4Addr::new(10, 0, 0, 1));
unsafe { std::env::remove_var("EPICS_RS_CLIENT_IGNORE") };
}
// Unresolvable tokens are dropped without affecting valid neighbours. The
// `.invalid` TLD is reserved and never resolves.
#[test]
#[serial(epics_env)]
fn silently_drops_garbage_entries() {
unsafe {
std::env::set_var(
"EPICS_RS_CLIENT_IGNORE",
"1.2.3.4 not-an-ip-or-host-that-resolves.invalid 5.6.7.8",
)
};
let v = epics_rs_client_ignore();
assert!(v.contains(&Ipv4Addr::new(1, 2, 3, 4)));
assert!(v.contains(&Ipv4Addr::new(5, 6, 7, 8)));
assert_eq!(v.len(), 2);
unsafe { std::env::remove_var("EPICS_RS_CLIENT_IGNORE") };
}
}
// Tests for `append_auto_addr_entries`: loopback fallback only when no
// broadcast addresses were discovered, limited broadcast always added, no
// duplicates, and the caller's port honoured.
#[cfg(test)]
mod auto_addr_list_tests {
use super::*;
// Shorthand for an IPv4 socket address.
fn sock(ip: Ipv4Addr, port: u16) -> SocketAddr {
SocketAddr::V4(SocketAddrV4::new(ip, port))
}
// No broadcasts: both loopback and 255.255.255.255 are added.
#[test]
fn loopback_added_when_no_broadcasts() {
let mut addrs: Vec<AddrEntry> = Vec::new();
append_auto_addr_entries(&mut addrs, &[], 5064);
assert!(
addrs
.iter()
.any(|e| e.sock == sock(Ipv4Addr::LOCALHOST, 5064))
);
assert!(
addrs
.iter()
.any(|e| e.sock == sock(Ipv4Addr::BROADCAST, 5064))
);
}
// With a discovered broadcast: it and 255.255.255.255 are added, loopback is not.
#[test]
fn loopback_not_added_when_broadcasts_present() {
let mut addrs: Vec<AddrEntry> = Vec::new();
let bcasts = vec![Ipv4Addr::new(192, 168, 1, 255)];
append_auto_addr_entries(&mut addrs, &bcasts, 5064);
assert!(
addrs
.iter()
.any(|e| e.sock == sock(Ipv4Addr::new(192, 168, 1, 255), 5064))
);
assert!(
!addrs
.iter()
.any(|e| e.sock == sock(Ipv4Addr::LOCALHOST, 5064))
);
assert!(
addrs
.iter()
.any(|e| e.sock == sock(Ipv4Addr::BROADCAST, 5064))
);
}
// A pre-existing loopback entry is not duplicated.
#[test]
fn does_not_duplicate_existing_entries() {
let mut addrs: Vec<AddrEntry> =
vec![AddrEntry::new(sock(Ipv4Addr::LOCALHOST, 5064), None, 5064)];
append_auto_addr_entries(&mut addrs, &[], 5064);
let count = addrs
.iter()
.filter(|e| e.sock == sock(Ipv4Addr::LOCALHOST, 5064))
.count();
assert_eq!(count, 1, "loopback must not be duplicated");
}
// Auto entries use the supplied port rather than a hard-coded 5064.
#[test]
fn respects_custom_server_port() {
let mut addrs: Vec<AddrEntry> = Vec::new();
append_auto_addr_entries(&mut addrs, &[], 5066);
assert!(
addrs
.iter()
.any(|e| e.sock == sock(Ipv4Addr::LOCALHOST, 5066))
);
assert!(
addrs
.iter()
.any(|e| e.sock == sock(Ipv4Addr::BROADCAST, 5066))
);
}
}
// Tests for `AddrEntry` construction and `refresh_dns` on IP-literal entries
// (no hostname, so no DNS lookup happens).
#[cfg(test)]
mod addr_entry_tests {
use super::*;
// An entry built without a hostname keeps `hostname == None`.
#[test]
fn ip_literal_has_no_hostname() {
let entry = AddrEntry::new(
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 5064)),
None,
5064,
);
assert!(entry.hostname.is_none());
}
// Refreshing a hostname-less entry succeeds and returns the original address.
#[test]
fn refresh_noop_for_literal_ip() {
let original_sock = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 5064));
let mut entry = AddrEntry::new(original_sock, None, 5064);
let refreshed = entry.refresh_dns().expect("noop refresh succeeds");
assert_eq!(refreshed, original_sock);
}
}
/// Apply `$(VAR)` / `${VAR}` substitution to a PV name when
/// `EPICS_CA_USE_SHELL_VARS` is `YES` (case-insensitive).
///
/// When the variable is unset or has any other value, the name is returned
/// unchanged.
fn expand_pv_name(name: &str) -> String {
    let use_shell_vars = epics_base_rs::runtime::env::get_or("EPICS_CA_USE_SHELL_VARS", "NO")
        .eq_ignore_ascii_case("YES");
    if !use_shell_vars {
        return name.to_string();
    }
    expand_shell_vars(name)
}
/// Replace every `$(NAME)` and `${NAME}` in `s` with the value of environment
/// variable `NAME`.
///
/// Unset variables expand to the empty string. A `$` that is not followed by
/// `(` or `{`, or whose reference is never closed, is kept literally. Nested
/// references are not supported: the first matching closing delimiter ends
/// the name.
///
/// Literal text is copied as `&str` slices, so non-ASCII UTF-8 in the input is
/// preserved. Pushing individual bytes as `char` would reinterpret each byte
/// of a multi-byte character as Latin-1 and corrupt it. All delimiters are
/// ASCII, so every slice boundary used here falls on a character boundary.
fn expand_shell_vars(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let bytes = s.as_bytes();
    // Start of the pending literal run not yet copied to `out`.
    let mut literal_start = 0;
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'$' && i + 1 < bytes.len() {
            let close = match bytes[i + 1] {
                b'{' => Some(b'}'),
                b'(' => Some(b')'),
                _ => None,
            };
            if let Some(end) = close {
                if let Some(j) = bytes[i + 2..].iter().position(|&b| b == end) {
                    out.push_str(&s[literal_start..i]);
                    let name = &s[i + 2..i + 2 + j];
                    let value = epics_base_rs::runtime::env::get(name).unwrap_or_default();
                    out.push_str(&value);
                    // Skip `$`, the opening delimiter, the name and the closer.
                    i += 3 + j;
                    literal_start = i;
                    continue;
                }
            }
        }
        i += 1;
    }
    out.push_str(&s[literal_start..]);
    out
}
/// Parse `EPICS_CA_TLS_SNI_MAP` into `(address, TLS server name)` pairs.
///
/// The variable holds whitespace-separated `ADDR=HOSTNAME` tokens. `ADDR` is
/// either a socket address (`IP:port`, or `[v6]:port`) or a bare IP literal
/// (IPv4 or IPv6). A bare IP is stored with port 0 (presumably "any port" for
/// the lookup — confirm against the consumer).
///
/// Tokens without `=`, with an empty hostname, or with an unparseable address
/// are skipped with a warning. An unset variable yields an empty list.
#[cfg(feature = "experimental-rust-tls")]
fn parse_tls_sni_map() -> Vec<(SocketAddr, String)> {
    let Some(list) = epics_base_rs::runtime::env::get("EPICS_CA_TLS_SNI_MAP") else {
        return Vec::new();
    };
    let mut out = Vec::new();
    for entry in list.split_whitespace() {
        let Some((addr_part, host)) = entry.split_once('=') else {
            tracing::warn!(entry = %entry,
                "EPICS_CA_TLS_SNI_MAP entry missing '=', skipping");
            continue;
        };
        if host.is_empty() {
            tracing::warn!(entry = %entry,
                "EPICS_CA_TLS_SNI_MAP entry has empty hostname, skipping");
            continue;
        }
        // Try a full socket address first, then a bare IP. Deciding by the
        // presence of ':' would misroute a bare IPv6 literal such as `::1`
        // to the socket-address parser and reject it.
        let addr = if let Ok(sock) = addr_part.parse::<SocketAddr>() {
            sock
        } else if let Ok(ip) = addr_part.parse::<std::net::IpAddr>() {
            SocketAddr::new(ip, 0)
        } else {
            if addr_part.contains(':') {
                tracing::warn!(entry = %entry,
                    "EPICS_CA_TLS_SNI_MAP entry has unparseable IP:port, skipping");
            } else {
                tracing::warn!(entry = %entry,
                    "EPICS_CA_TLS_SNI_MAP entry has unparseable IP, skipping");
            }
            continue;
        };
        out.push((addr, host.to_string()));
    }
    out
}
pub(crate) fn parse_nameserver_list() -> Vec<(SocketAddr, Option<String>)> {
let Some(list) = epics_base_rs::runtime::env::get("EPICS_CA_NAME_SERVERS") else {
return Vec::new();
};
let default_server_port: u16 = epics_base_rs::runtime::env::get("EPICS_CA_SERVER_PORT")
.and_then(|s| s.parse::<u16>().ok())
.unwrap_or(CA_SERVER_PORT);
let mut out = Vec::new();
for entry in list.split_whitespace() {
if entry.contains(':') {
if let Ok(addr) = entry.parse::<SocketAddr>() {
out.push((addr, None));
continue;
}
let Some((host, port_str)) = entry.rsplit_once(':') else {
continue;
};
let Ok(port) = port_str.parse::<u16>() else {
continue;
};
if let Ok(addr) = resolve_host(host, port) {
let hostname = if host.parse::<std::net::IpAddr>().is_ok() {
None
} else {
Some(host.to_string())
};
out.push((addr, hostname));
}
} else {
if let Ok(addr) = resolve_host(entry, default_server_port) {
let hostname = if entry.parse::<std::net::IpAddr>().is_ok() {
None
} else {
Some(entry.to_string())
};
out.push((addr, hostname));
}
}
}
out
}
/// Parse `EPICS_RS_CLIENT_IGNORE` into the IPv4 addresses the client should
/// ignore.
///
/// The variable holds whitespace-separated tokens. Any `:port` suffix is
/// stripped, since only the IP matters. IPv4 literals are used directly;
/// other tokens are resolved as hostnames and contribute their first IPv4
/// address. A token is dropped with a debug log when the lookup fails or
/// yields no IPv4 address. An unset variable yields an empty list.
pub(crate) fn epics_rs_client_ignore() -> Vec<Ipv4Addr> {
    use std::net::ToSocketAddrs;

    let Some(list) = epics_base_rs::runtime::env::get("EPICS_RS_CLIENT_IGNORE") else {
        return Vec::new();
    };
    let mut out = Vec::new();
    for entry in list.split_whitespace() {
        let host = entry.rsplit_once(':').map_or(entry, |(h, _)| h);
        if let Ok(ip) = host.parse::<Ipv4Addr>() {
            out.push(ip);
            continue;
        }
        // Port 0 is only a placeholder for the lookup; just the IP is kept.
        // A lookup error and an IPv4-less result are treated the same way.
        let first_v4 = format!("{host}:0")
            .to_socket_addrs()
            .ok()
            .and_then(|mut resolved| {
                resolved.find_map(|sa| match sa {
                    SocketAddr::V4(v4) => Some(*v4.ip()),
                    SocketAddr::V6(_) => None,
                })
            });
        match first_v4 {
            Some(ip) => out.push(ip),
            None => tracing::debug!(
                target: "epics_ca_rs::client",
                entry = %entry,
                "EPICS_RS_CLIENT_IGNORE: dropped unresolvable entry"
            ),
        }
    }
    out
}
// Tests for `beacon_arrival_targets`: exact address match wins, otherwise
// fall back to operational servers on the same port, and never target
// non-operational channels.
#[cfg(test)]
mod beacon_arrival_routing_tests {
use super::*;
// Parse a socket-address literal; panics on malformed test input.
fn addr(s: &str) -> SocketAddr {
s.parse().unwrap()
}
// An exact address match returns only that address.
#[test]
fn exact_match_dominates() {
let states = vec![
(ChannelState::Connected, Some(addr("10.0.0.1:5064"))),
(ChannelState::Connected, Some(addr("10.0.0.2:5064"))),
];
let targets = beacon_arrival_targets(states, addr("10.0.0.1:5064"));
assert_eq!(targets, vec![addr("10.0.0.1:5064")]);
}
// A beacon from 0.0.0.0 matches by port only; other ports are excluded.
// Result order is unspecified, hence the sort.
#[test]
fn unspecified_addr_falls_back_to_port_match() {
let states = vec![
(ChannelState::Connected, Some(addr("10.0.0.1:5064"))),
(ChannelState::Connected, Some(addr("10.0.0.2:5064"))),
(ChannelState::Connected, Some(addr("10.0.0.3:6000"))),
];
let mut targets = beacon_arrival_targets(states, addr("0.0.0.0:5064"));
targets.sort();
assert_eq!(
targets,
vec![addr("10.0.0.1:5064"), addr("10.0.0.2:5064")],
":6000 must NOT be a target for a :5064 beacon"
);
}
// A beacon from a different IP on the same port still reaches the server.
#[test]
fn multi_homed_falls_back_to_port_match() {
let states = vec![
(ChannelState::Connected, Some(addr("10.0.0.2:5064"))),
];
let targets = beacon_arrival_targets(states, addr("10.0.0.1:5064"));
assert_eq!(
targets,
vec![addr("10.0.0.2:5064")],
"multi-homed IOC must be reachable via port-only fallback"
);
}
// Searching/Disconnected channels are ignored even on an exact address match.
#[test]
fn non_operational_channels_do_not_match() {
let states = vec![
(ChannelState::Searching, Some(addr("10.0.0.1:5064"))),
(ChannelState::Disconnected, Some(addr("10.0.0.2:5064"))),
];
let targets = beacon_arrival_targets(states, addr("10.0.0.1:5064"));
assert!(
targets.is_empty(),
"non-operational channels must not generate watchdog notifies"
);
}
// Different IP and port: no targets.
#[test]
fn no_match_returns_empty() {
let states = vec![(ChannelState::Connected, Some(addr("10.0.0.1:5064")))];
let targets = beacon_arrival_targets(states, addr("10.0.0.99:5065"));
assert!(targets.is_empty());
}
// Several channels on the exact address still produce a single target.
#[test]
fn exact_match_emits_single_notify() {
let states = vec![
(ChannelState::Connected, Some(addr("10.0.0.1:5064"))),
(ChannelState::Connected, Some(addr("10.0.0.1:5064"))),
];
let targets = beacon_arrival_targets(states, addr("10.0.0.1:5064"));
assert_eq!(targets, vec![addr("10.0.0.1:5064")]);
}
}
// Tests for `operational_circuit_count`: circuits are keyed by
// (address, priority), and only operational channels with an address count.
#[cfg(test)]
mod ioc_connection_count_tests {
use super::*;
// Parse a socket-address literal; panics on malformed test input.
fn addr(s: &str) -> SocketAddr {
s.parse().unwrap()
}
// Same IOC at two priorities = two circuits.
#[test]
fn two_priorities_to_one_ioc_count_as_two_circuits() {
let states = vec![
(ChannelState::Connected, Some(addr("10.0.0.1:5064")), 0u8),
(ChannelState::Connected, Some(addr("10.0.0.1:5064")), 1u8),
];
assert_eq!(operational_circuit_count(states), 2);
}
// Many channels sharing (address, priority) = one circuit.
#[test]
fn many_channels_one_circuit_count_as_one() {
let states = vec![
(ChannelState::Connected, Some(addr("10.0.0.1:5064")), 0u8),
(ChannelState::Connected, Some(addr("10.0.0.1:5064")), 0u8),
(ChannelState::Connected, Some(addr("10.0.0.1:5064")), 0u8),
];
assert_eq!(operational_circuit_count(states), 1);
}
// Different addresses are separate circuits.
#[test]
fn distinct_addresses_count_separately() {
let states = vec![
(ChannelState::Connected, Some(addr("10.0.0.1:5064")), 0u8),
(ChannelState::Connected, Some(addr("10.0.0.2:5064")), 0u8),
];
assert_eq!(operational_circuit_count(states), 2);
}
// Searching/Disconnected channels contribute nothing.
#[test]
fn non_operational_channels_excluded() {
let states = vec![
(ChannelState::Searching, Some(addr("10.0.0.1:5064")), 0u8),
(ChannelState::Disconnected, Some(addr("10.0.0.2:5064")), 1u8),
];
assert_eq!(operational_circuit_count(states), 0);
}
// A connected channel without a server address is not counted.
#[test]
fn missing_server_addr_excluded() {
let states = vec![(ChannelState::Connected, None, 0u8)];
assert_eq!(operational_circuit_count(states), 0);
}
}
// Only has content when the `experimental-rust-tls` feature is enabled;
// otherwise the module compiles to nothing.
#[cfg(test)]
mod tls_sni_config_tests {
#[cfg(feature = "experimental-rust-tls")]
use super::*;
// `CaClientConfig::tls_server_name` defaults to None and stores what is set.
#[cfg(feature = "experimental-rust-tls")]
#[test]
fn tls_server_name_round_trip() {
let mut cfg = CaClientConfig::default();
assert!(cfg.tls_server_name.is_none(), "default must be None");
cfg.tls_server_name = Some("ioc.example.com".into());
assert_eq!(cfg.tls_server_name.as_deref(), Some("ioc.example.com"));
}
}
// Tests for `drain_waiters_for_cids`: only waiters of the given channel ids are
// removed and failed with `CaError::Disconnected`; others stay registered.
#[cfg(test)]
mod waiter_drain_tests {
use super::*;
use std::collections::HashSet;
use tokio::sync::oneshot;
// cid 42's read/write waiters get Disconnected; cid 99's stay registered
// until `in_flight` is dropped, after which their receivers see a closed
// channel (`is_err`).
#[tokio::test(flavor = "current_thread")]
async fn drain_wakes_matching_cid_only() {
let in_flight = types::InFlightOps::new();
let (rtx_42, rrx_42) = oneshot::channel();
let (rtx_99, rrx_99) = oneshot::channel();
let (wtx_42, wrx_42) = oneshot::channel();
let (wtx_99, wrx_99) = oneshot::channel();
in_flight.reads.insert(
1001,
types::ReadWaiter::OneShot {
cid: 42,
mode: types::ReadReplyMode::Raw,
reply_tx: rtx_42,
},
);
in_flight.reads.insert(
2001,
types::ReadWaiter::OneShot {
cid: 99,
mode: types::ReadReplyMode::Raw,
reply_tx: rtx_99,
},
);
in_flight.writes.insert(1002, (42, wtx_42));
in_flight.writes.insert(2002, (99, wtx_99));
let mut affected = HashSet::new();
affected.insert(42u32);
drain_waiters_for_cids(&affected, &in_flight);
assert!(!in_flight.reads.contains_key(&1001));
assert!(!in_flight.writes.contains_key(&1002));
assert!(matches!(rrx_42.await, Ok(Err(CaError::Disconnected))));
assert!(matches!(wrx_42.await, Ok(Err(CaError::Disconnected))));
assert!(in_flight.reads.contains_key(&2001));
assert!(in_flight.writes.contains_key(&2002));
drop(in_flight);
assert!(rrx_99.await.is_err()); assert!(wrx_99.await.is_err());
}
// An empty cid set leaves every waiter in place.
#[tokio::test(flavor = "current_thread")]
async fn drain_with_empty_cid_set_is_noop() {
let in_flight = types::InFlightOps::new();
let (rtx, rrx) = oneshot::channel();
let (wtx, wrx) = oneshot::channel();
in_flight.reads.insert(
10,
types::ReadWaiter::OneShot {
cid: 1,
mode: types::ReadReplyMode::Raw,
reply_tx: rtx,
},
);
in_flight.writes.insert(20, (2, wtx));
let affected: HashSet<u32> = HashSet::new();
drain_waiters_for_cids(&affected, &in_flight);
assert!(in_flight.reads.contains_key(&10));
assert!(in_flight.writes.contains_key(&20));
drop(in_flight);
assert!(rrx.await.is_err());
assert!(wrx.await.is_err());
}
// A reply delivered (and its entry removed) before the drain is not
// overwritten: the receiver still gets the original Ok reply.
#[tokio::test(flavor = "current_thread")]
async fn response_arrives_before_disconnect_drain() {
let in_flight = types::InFlightOps::new();
let (rtx, rrx) = oneshot::channel();
in_flight.reads.insert(
100,
types::ReadWaiter::OneShot {
cid: 7,
mode: types::ReadReplyMode::Raw,
reply_tx: rtx,
},
);
if let Some((_, waiter)) = in_flight.reads.remove(&100) {
waiter.send(Ok(types::ReadReply::Raw {
data_type: 6,
count: 1,
data: vec![1, 0, 0, 0],
}));
}
let mut affected = HashSet::new();
affected.insert(7u32);
drain_waiters_for_cids(&affected, &in_flight);
let result = rrx.await.expect("oneshot still alive");
assert!(matches!(
result,
Ok(types::ReadReply::Raw {
data_type: 6,
count: 1,
..
})
));
}
}
// Tests that an `EventWatcher` stops its spawned task both when dropped and
// when `abort()` is called. Each test polls the abort handle for up to 100
// yields rather than asserting immediately after the abort.
#[cfg(test)]
mod event_watcher_tests {
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
// The task runs (flag set) and is still alive before the drop; after the
// drop it finishes.
#[tokio::test(flavor = "current_thread")]
async fn drop_aborts_watcher_task() {
let ran = Arc::new(AtomicBool::new(false));
let ran_in_task = ran.clone();
let handle = epics_base_rs::runtime::task::spawn(async move {
ran_in_task.store(true, Ordering::SeqCst);
loop {
tokio::task::yield_now().await;
}
});
let abort_handle = handle.abort_handle();
let watcher = EventWatcher { handle };
// Give the current-thread runtime a chance to start the task.
tokio::task::yield_now().await;
assert!(
ran.load(Ordering::SeqCst),
"watcher task should have started"
);
assert!(
!abort_handle.is_finished(),
"task still running before drop"
);
drop(watcher);
for _ in 0..100 {
if abort_handle.is_finished() {
break;
}
tokio::task::yield_now().await;
}
assert!(
abort_handle.is_finished(),
"EventWatcher::drop must abort the watcher task"
);
}
// Calling `abort()` explicitly also stops the task.
#[tokio::test(flavor = "current_thread")]
async fn explicit_abort_stops_watcher_task() {
let handle = epics_base_rs::runtime::task::spawn(async move {
loop {
tokio::task::yield_now().await;
}
});
let abort_handle = handle.abort_handle();
let watcher = EventWatcher { handle };
watcher.abort();
for _ in 0..100 {
if abort_handle.is_finished() {
break;
}
tokio::task::yield_now().await;
}
assert!(
abort_handle.is_finished(),
"EventWatcher::abort must stop the watcher task"
);
}
}
// Tests the header fields of frames built by `CaChannel::build_write_frame`.
// Judging by the assertions, the arguments appear to be (command, sid,
// data_type, count, optional ioid, payload) — confirm against the definition.
#[cfg(test)]
mod typed_string_put_tests {
use super::*;
// A string put is sent as DBR_STRING (0) and echoes the sid in `cid`.
#[test]
fn put_string_frame_uses_dbr_string_type() {
let payload = EpicsValue::String("Running".into()).to_bytes();
let frame = CaChannel::build_write_frame(
CA_PROTO_WRITE,
77,
0,
1,
None,
payload,
);
let (hdr, _consumed) =
CaHeader::from_bytes_extended(&frame).expect("frame header must parse");
assert_eq!(hdr.cmmd, CA_PROTO_WRITE, "command must be CA_PROTO_WRITE");
assert_eq!(
hdr.data_type, 0,
"typed string-put must wire DBR_STRING (0), not the native type"
);
assert_eq!(hdr.cid, 77, "sid echoed in cid field");
}
// A native DOUBLE put keeps its type and echoes the ioid in `available`.
#[test]
fn native_typed_put_frame_keeps_native_type() {
let payload = EpicsValue::Double(1.5).to_bytes();
let native_double = DbFieldType::Double as u16; let frame = CaChannel::build_write_frame(
CA_PROTO_WRITE_NOTIFY,
5,
native_double,
1,
Some(42),
payload,
);
let (hdr, _consumed) =
CaHeader::from_bytes_extended(&frame).expect("frame header must parse");
assert_eq!(hdr.data_type, native_double);
assert_ne!(
hdr.data_type, 0,
"native Double put must not collapse to DBR_STRING"
);
assert_eq!(hdr.available, 42, "ioid echoed in available field");
}
// A string-array put uses DBR_STRING, 40-byte elements, and the element count.
#[test]
fn put_string_array_frame_uses_dbr_string_type_and_count() {
let values = [
"Running".to_string(),
"Stopped".to_string(),
"Paused".to_string(),
];
let payload =
EpicsValue::StringArray(values.iter().map(|s| s.clone().into()).collect()).to_bytes();
assert_eq!(
payload.len(),
values.len() * 40,
"DBR_STRING array is 40 bytes per element"
);
let frame = CaChannel::build_write_frame(
CA_PROTO_WRITE_NOTIFY,
9,
0,
values.len() as u32,
Some(7),
payload,
);
let (hdr, _consumed) =
CaHeader::from_bytes_extended(&frame).expect("frame header must parse");
assert_eq!(
hdr.data_type, 0,
"string-array put must wire DBR_STRING (0)"
);
assert_eq!(hdr.count, 3, "count must be the element count");
}
}
#[cfg(test)]
mod user_data_tests {
    use super::*;
    // Imported explicitly instead of relying on `super::*`: the top-of-file
    // `use` list does not import `Arc`. An explicit import takes precedence
    // over a glob, so this stays correct even if the parent imports it too.
    use std::sync::Arc;

    /// Round-trips the per-channel user-data slot. It checks set, typed get,
    /// wrong-type downcast, visibility through a clone, and clear.
    #[test]
    fn set_get_downcast_clone_share_and_clear() {
        let (coord_tx, _rx) = mpsc::unbounded_channel();
        let ch = CaChannel::for_test(coord_tx);
        assert!(
            ch.user_data::<u32>().is_none(),
            "fresh channel must have no user data"
        );

        ch.set_user_data(Arc::new(42u32));
        assert_eq!(ch.user_data::<u32>().as_deref(), Some(&42u32));
        assert!(
            ch.user_data::<String>().is_none(),
            "downcast to the wrong type must yield None"
        );

        let ch2 = ch.clone();
        assert_eq!(
            ch2.user_data::<u32>().as_deref(),
            Some(&42u32),
            "clone must see user data set before cloning"
        );

        assert!(ch.clear_user_data().is_some(), "clear must return the stored value");
        assert!(ch.user_data::<u32>().is_none());
        assert!(
            ch2.user_data::<u32>().is_none(),
            "clearing through one handle must clear it for its clone"
        );
    }
}
#[cfg(test)]
mod monitor_pause_tests {
    use super::*;
    use epics_base_rs::server::snapshot::Snapshot;
    use epics_base_rs::types::EpicsValue;
    use std::time::{Duration, SystemTime};

    /// Builds a `Long` snapshot stamped with the current wall-clock time.
    fn long_snapshot(value: i32) -> Snapshot {
        Snapshot::new(EpicsValue::Long(value), 0, 0, SystemTime::now())
    }

    /// Values queued before `pause()` still drain. A value routed into the
    /// coalesce slot while paused is withheld until `resume()`.
    #[tokio::test]
    async fn pause_holds_value_resume_releases_after_backlog() {
        let (value_tx, value_rx) = mpsc::channel::<CaResult<Snapshot>>(4);
        let (coord_tx, _coord_rx) = mpsc::unbounded_channel();
        let slot = subscription::CoalesceSlot::new();
        let mut monitor = MonitorHandle {
            subid: 1,
            callback_rx: value_rx,
            coalesce_slot: slot.clone(),
            channel: CaChannel::for_test(coord_tx.clone()),
            coord_tx,
        };

        // Backlog entry that exists before the pause takes effect.
        value_tx
            .try_send(Ok(long_snapshot(1)))
            .expect("enqueue backlog");
        monitor.pause();
        let routed = slot.route_value(long_snapshot(2));
        assert!(matches!(routed, subscription::ValueRoute::Slotted));

        let backlog = tokio::time::timeout(Duration::from_millis(200), monitor.recv())
            .await
            .expect("backlog should deliver promptly")
            .expect("Some")
            .expect("Ok");
        assert_eq!(backlog.value, EpicsValue::Long(1), "pre-pause backlog (1)");

        let while_paused = tokio::time::timeout(Duration::from_millis(150), monitor.recv()).await;
        assert!(
            while_paused.is_err(),
            "held value must stay withheld while paused (recv-side gate)"
        );

        monitor.resume();
        let released = tokio::time::timeout(Duration::from_millis(200), monitor.recv())
            .await
            .expect("held value should deliver after resume")
            .expect("Some")
            .expect("Ok");
        assert_eq!(released.value, EpicsValue::Long(2), "held latest after resume");
    }

    /// Requested queue sizes at or below the flow-control threshold are raised
    /// to it. Larger sizes pass through, and no request yields 256.
    #[test]
    fn monitor_queue_clamped_to_flow_control_threshold() {
        for requested in [5, 9, FLOW_CONTROL_OFF_THRESHOLD] {
            assert_eq!(
                resolve_monitor_queue_size(Some(requested)),
                FLOW_CONTROL_OFF_THRESHOLD,
                "requested size {requested} must clamp to the threshold"
            );
        }
        assert_eq!(resolve_monitor_queue_size(Some(1000)), 1000);
        assert_eq!(resolve_monitor_queue_size(None), 256);
    }

    /// Errors pushed to a paused monitor are delivered immediately rather
    /// than withheld.
    #[tokio::test]
    async fn pause_does_not_hide_errors() {
        let (value_tx, value_rx) = mpsc::channel::<CaResult<Snapshot>>(4);
        let (coord_tx, _coord_rx) = mpsc::unbounded_channel();
        let slot = subscription::CoalesceSlot::new();
        let mut monitor = MonitorHandle {
            subid: 1,
            callback_rx: value_rx,
            coalesce_slot: slot,
            channel: CaChannel::for_test(coord_tx.clone()),
            coord_tx,
        };

        monitor.pause();
        // 192 is the server error code this test labels ECA_DISCONN.
        value_tx
            .try_send(Err(CaError::ServerError(192)))
            .expect("enqueue error");

        let delivered = tokio::time::timeout(Duration::from_millis(200), monitor.recv())
            .await
            .expect("error must deliver while paused")
            .expect("Some");
        assert!(
            matches!(delivered, Err(CaError::ServerError(192))),
            "ECA_DISCONN must bypass pause"
        );
    }
}