mod beacon_monitor;
mod circuit_breaker;
mod search;
mod state;
mod subscription;
mod sync_group;
mod transport;
mod types;
pub use sync_group::{SyncGroup, SyncGroupResults};
pub use circuit_breaker::{BreakerConfig, BreakerState};
use std::collections::{HashMap, HashSet, VecDeque};
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::time::Duration;
use epics_base_rs::runtime::sync::{broadcast, mpsc, oneshot};
use parking_lot::Mutex;
use crate::channel::{AccessRights, ChannelInfo, alloc_cid, alloc_ioid, alloc_subid};
use crate::protocol::*;
use crate::repeater;
use epics_base_rs::error::{CaError, CaResult};
use epics_base_rs::server::snapshot::{DbrClass, Snapshot};
use epics_base_rs::types::{DbFieldType, EpicsValue, decode_dbr};
pub use state::{ChannelState, ConnectionEvent};
use state::ChannelInner;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
/// A client lifecycle event kept in the diagnostic history (see `CaDiagnostics::record`).
#[derive(Debug, Clone)]
pub enum DiagEvent {
/// A channel for `pv` connected to `server`.
Connected {
pv: String,
server: SocketAddr,
},
/// The connection to `server` went away; `channels` is the channel count
/// reported with the event (presumably the channels affected — confirm at the call site).
Disconnected {
server: SocketAddr,
channels: usize,
},
/// `pv` reconnected; `restored` / `stale` are subscription counts reported with
/// the event (NOTE(review): presumably mirrored in the `subscriptions_*` counters).
Reconnected {
pv: String,
restored: u32,
stale: u32,
},
/// `server` was reported unresponsive.
Unresponsive {
server: SocketAddr,
},
/// `server` was reported responsive.
Responsive {
server: SocketAddr,
},
/// A beacon anomaly was reported for `server`.
BeaconAnomaly {
server: SocketAddr,
},
}
impl std::fmt::Display for DiagEvent {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Connected { pv, server } => write!(f, "Connected {pv} @ {server}"),
Self::Disconnected { server, channels } => {
write!(f, "Disconnected {server} ({channels} channels)")
}
Self::Reconnected {
pv,
restored,
stale,
} => write!(f, "Reconnected {pv} (restored={restored}, stale={stale})"),
Self::Unresponsive { server } => write!(f, "Unresponsive {server}"),
Self::Responsive { server } => write!(f, "Responsive {server}"),
Self::BeaconAnomaly { server } => write!(f, "Beacon anomaly {server}"),
}
}
}
/// One entry of the diagnostic history: an event plus the moment it was recorded.
#[derive(Debug, Clone)]
pub struct DiagRecord {
/// `Instant` captured by `CaDiagnostics::record` when the event was stored.
pub time: std::time::Instant,
/// The recorded event.
pub event: DiagEvent,
}
/// Maximum number of `DiagRecord`s kept by `CaDiagnostics`; the oldest record is
/// discarded first once the history is full.
const EVENT_HISTORY_CAPACITY: usize = 256;
/// Maximum number of channels kept in `OneShotChannelCache`; beyond this the
/// earliest-inserted names are evicted.
const ONE_SHOT_CHANNEL_CACHE_CAPACITY: usize = 4096;
/// Bounded cache of channels reused by `CaClient::caget_many*`, keyed by the
/// expanded PV name.
#[derive(Default)]
struct OneShotChannelCache {
/// Cached channels by expanded PV name.
channels: HashMap<String, CaChannel>,
/// Names in insertion order; the front is evicted first (FIFO, not LRU).
order: VecDeque<String>,
}
impl OneShotChannelCache {
    /// Returns the cached channel for `pv_name`, creating and caching a new one
    /// through `client` on a miss.
    ///
    /// After an insertion the cache is trimmed back to
    /// `ONE_SHOT_CHANNEL_CACHE_CAPACITY` entries by evicting the names that were
    /// inserted first. Evicted channels are only dropped from the cache; other
    /// clones held by callers keep working.
    fn get_or_create(&mut self, client: &CaClient, pv_name: String) -> CaChannel {
        if let Some(existing) = self.channels.get(&pv_name) {
            return existing.clone();
        }
        let fresh = client.create_channel_expanded(pv_name.clone());
        self.order.push_back(pv_name.clone());
        self.channels.insert(pv_name, fresh.clone());
        while self.channels.len() > ONE_SHOT_CHANNEL_CACHE_CAPACITY {
            match self.order.pop_front() {
                Some(evicted) => {
                    self.channels.remove(&evicted);
                }
                None => break,
            }
        }
        fresh
    }
}
/// Client-wide diagnostic counters plus a bounded history of `DiagEvent`s.
///
/// The counters are public atomics (NOTE(review): they appear to be incremented
/// outside this chunk, presumably by the coordinator — confirm). `snapshot`
/// reads them with `Ordering::Relaxed`.
pub struct CaDiagnostics {
pub connections: AtomicU64,
pub disconnections: AtomicU64,
pub reconnections: AtomicU64,
pub unresponsive_events: AtomicU64,
pub subscriptions_restored: AtomicU64,
pub subscriptions_stale: AtomicU64,
pub beacon_anomalies: AtomicU64,
pub search_requests: AtomicU64,
pub dropped_monitors: AtomicU64,
// Recent events, capped at EVENT_HISTORY_CAPACITY by `record`.
history: std::sync::Mutex<Vec<DiagRecord>>,
}
impl Default for CaDiagnostics {
    /// Every counter starts at zero; the history starts empty, pre-sized for
    /// `EVENT_HISTORY_CAPACITY` records.
    fn default() -> Self {
        let zero = || AtomicU64::new(0);
        let history = std::sync::Mutex::new(Vec::with_capacity(EVENT_HISTORY_CAPACITY));
        Self {
            connections: zero(),
            disconnections: zero(),
            reconnections: zero(),
            unresponsive_events: zero(),
            subscriptions_restored: zero(),
            subscriptions_stale: zero(),
            beacon_anomalies: zero(),
            search_requests: zero(),
            dropped_monitors: zero(),
            history,
        }
    }
}
impl CaDiagnostics {
pub fn record(&self, event: DiagEvent) {
let record = DiagRecord {
time: std::time::Instant::now(),
event,
};
if let Ok(mut history) = self.history.lock() {
if history.len() >= EVENT_HISTORY_CAPACITY {
history.remove(0);
}
history.push(record);
}
}
pub fn snapshot(&self) -> DiagnosticsSnapshot {
let history = self.history.lock().map(|h| h.clone()).unwrap_or_default();
DiagnosticsSnapshot {
connections: self.connections.load(Ordering::Relaxed),
disconnections: self.disconnections.load(Ordering::Relaxed),
reconnections: self.reconnections.load(Ordering::Relaxed),
unresponsive_events: self.unresponsive_events.load(Ordering::Relaxed),
subscriptions_restored: self.subscriptions_restored.load(Ordering::Relaxed),
subscriptions_stale: self.subscriptions_stale.load(Ordering::Relaxed),
beacon_anomalies: self.beacon_anomalies.load(Ordering::Relaxed),
search_requests: self.search_requests.load(Ordering::Relaxed),
dropped_monitors: self.dropped_monitors.load(Ordering::Relaxed),
history,
}
}
}
/// Point-in-time copy of `CaDiagnostics`, as returned by `CaClient::diagnostics`.
#[derive(Debug, Clone)]
pub struct DiagnosticsSnapshot {
pub connections: u64,
pub disconnections: u64,
pub reconnections: u64,
pub unresponsive_events: u64,
pub subscriptions_restored: u64,
pub subscriptions_stale: u64,
pub beacon_anomalies: u64,
pub search_requests: u64,
pub dropped_monitors: u64,
/// Recorded events, oldest first.
pub history: Vec<DiagRecord>,
}
impl std::fmt::Display for DiagnosticsSnapshot {
    /// Prints one line per counter and then, if any events were recorded, each
    /// event with its offset in seconds from the first recorded event.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Connections: {}", self.connections)?;
        writeln!(f, "Disconnections: {}", self.disconnections)?;
        writeln!(f, "Reconnections: {}", self.reconnections)?;
        writeln!(f, "Unresponsive events: {}", self.unresponsive_events)?;
        writeln!(f, "Subscriptions restored: {}", self.subscriptions_restored)?;
        writeln!(f, "Subscriptions stale: {}", self.subscriptions_stale)?;
        writeln!(f, "Beacon anomalies: {}", self.beacon_anomalies)?;
        writeln!(f, "Search requests: {}", self.search_requests)?;
        writeln!(f, "Dropped monitors: {}", self.dropped_monitors)?;
        let Some(first) = self.history.first() else {
            return Ok(());
        };
        writeln!(f, "Recent events ({}):", self.history.len())?;
        let origin = first.time;
        self.history.iter().try_for_each(|rec| {
            let offset = rec.time.duration_since(origin).as_secs_f64();
            writeln!(f, " +{:.1}s {}", offset, rec.event)
        })
    }
}
use subscription::SubscriptionRegistry;
use types::*;
pub use types::{CaException, CaExceptionHandler, CaExceptionKind};
/// Channel Access client: owns the background tasks and the shared state that
/// `CaChannel` handles use.
///
/// NOTE: field order is drop order; the task handles come last and are
/// aborted explicitly in `Drop`.
pub struct CaClient {
/// Requests to the search engine task.
search_tx: mpsc::UnboundedSender<SearchRequest>,
/// Commands to the transport manager task.
transport_tx: mpsc::UnboundedSender<TransportCommand>,
/// Requests to the coordinator task.
coord_tx: mpsc::UnboundedSender<CoordRequest>,
/// Channels reused by the `caget_many*` helpers.
one_shot_channels: Mutex<OneShotChannelCache>,
/// Outstanding read/write waiters keyed by ioid (shared with the transport).
in_flight: InFlightOps,
/// Per-cid channel snapshots read by `CaChannel::snapshot`.
snapshots: ChannelSnapshots,
/// Per-server writers used for direct (fast-path) frame sends.
server_writers: DirectServerWriters,
diagnostics: Arc<CaDiagnostics>,
/// Per-cid search attempt counters (read by `CaChannel::search_attempts`).
search_attempts: types::SearchAttempts,
/// Handler installed by `set_exception_handler`.
exception_slot: types::CaExceptionSlot,
_coordinator: tokio::task::JoinHandle<()>,
_search_task: tokio::task::JoinHandle<()>,
_transport_task: tokio::task::JoinHandle<()>,
_beacon_task: tokio::task::JoinHandle<()>,
}
/// Messages sent to the coordinator task (`run_coordinator`).
// NOTE(review): the allow presumably covers variants/fields that are not
// constructed or read in every build configuration (several variants are not
// constructed in this chunk) — confirm and narrow it if possible.
#[allow(dead_code)]
enum CoordRequest {
/// Register a newly created channel `cid` for `pv_name`; connection events are
/// published on `conn_tx` (sent by `CaClient::create_channel_expanded`).
RegisterChannel {
cid: u32,
pv_name: String,
conn_tx: broadcast::Sender<ConnectionEvent>,
},
/// Ask to be notified on `reply` when `cid` is connected; used by
/// `CaChannel::wait_connected`, which treats a dropped `reply` as shutdown.
WaitConnected {
cid: u32,
reply: oneshot::Sender<()>,
},
/// Create monitor `subid` on `cid`; updates are delivered on `callback_tx` and
/// the outcome of the request on `reply` (see `CaChannel::subscribe_with_deadband`).
Subscribe {
cid: u32,
subid: u32,
mask: u16,
deadband: f64,
callback_tx: mpsc::Sender<CaResult<Snapshot>>,
reply: oneshot::Sender<CaResult<()>>,
},
/// Cancel monitor `subid` (not sent in this chunk; presumably from `MonitorHandle`).
Unsubscribe {
subid: u32,
},
/// Monitor `subid` consumed an update (not sent in this chunk; presumably
/// from `MonitorHandle` — confirm).
MonitorConsumed {
subid: u32,
},
/// Release channel `cid` (sent by `ChannelLifecycle::drop` and the one-shot helpers).
DropChannel {
cid: u32,
},
/// Re-search channels on `server_addr` after a beacon anomaly of `kind`
/// (presumably from the beacon monitor, which holds a `coord_tx` clone).
ForceRescanServer {
server_addr: SocketAddr,
kind: beacon_monitor::BeaconAnomalyKind,
},
/// A beacon arrived from `server_addr` (presumably from the beacon monitor).
BeaconArrival {
server_addr: SocketAddr,
anomaly: bool,
},
/// Query the watchdog delay for `cid` (not sent in this chunk).
GetWatchdogDelay {
cid: u32,
reply: oneshot::Sender<Option<Duration>>,
},
/// Query the number of IOC connections (used by `CaClient::ioc_connection_count`).
GetIocConnectionCount {
reply: oneshot::Sender<usize>,
},
/// Query the server's minor protocol version for `cid` (not sent in this chunk).
GetHostMinorProtocol {
cid: u32,
reply: oneshot::Sender<Option<u16>>,
},
/// Shut down; `reply` is awaited (with a 2 s timeout) by `CaClient::shutdown` and `Drop`.
Shutdown {
reply: oneshot::Sender<()>,
},
}
/// Options for `CaClient::new_with_config`.
#[derive(Default)]
pub struct CaClientConfig {
/// TLS configuration (experimental, Rust-only); a server-side config is ignored with a warning.
#[cfg(feature = "experimental-rust-tls")]
pub tls: Option<crate::tls::TlsConfig>,
/// TLS server name passed to the transport manager.
#[cfg(feature = "experimental-rust-tls")]
pub tls_server_name: Option<String>,
/// Service-discovery settings; when `None`, `crate::discovery::from_env` is consulted.
pub discovery: Option<crate::discovery::DiscoveryConfig>,
/// Additional discovery backends queried alongside those built from `discovery`.
pub extra_backends: Vec<Box<dyn crate::discovery::Backend>>,
}
impl CaClient {
    /// Creates a client with the default [`CaClientConfig`].
    ///
    /// With the `experimental-rust-tls` feature enabled, the TLS configuration
    /// comes from `crate::tls::client_from_env` (an invalid `EPICS_CA_TLS_*`
    /// configuration is logged and plaintext is used) and the TLS server name
    /// from `EPICS_CA_TLS_SERVER_NAME`.
    ///
    /// # Errors
    /// Propagates any error from [`CaClient::new_with_config`].
    pub async fn new() -> CaResult<Self> {
        #[cfg(feature = "experimental-rust-tls")]
        let cfg = {
            let mut c = CaClientConfig::default();
            match crate::tls::client_from_env() {
                Ok(Some(tls)) => c.tls = Some(tls),
                Ok(None) => {}
                Err(e) => {
                    tracing::error!(error = %e,
                        "EPICS_CA_TLS_* configuration is invalid; using plaintext");
                }
            }
            c.tls_server_name = epics_base_rs::runtime::env::get("EPICS_CA_TLS_SERVER_NAME");
            c
        };
        #[cfg(not(feature = "experimental-rust-tls"))]
        let cfg = CaClientConfig::default();
        Self::new_with_config(cfg).await
    }

    /// Creates a client from `config` and spawns its background tasks:
    /// coordinator, search engine, transport manager and beacon monitor.
    ///
    /// The search address list comes from `parse_addr_list_with_hostnames`.
    /// Addresses reported by the discovery backends are appended to it if they
    /// are not already present. The backends are built from `config.discovery`
    /// (or `crate::discovery::from_env`), plus `config.extra_backends`.
    ///
    /// # Errors
    /// Returns the error from `parse_addr_list_with_hostnames`.
    pub async fn new_with_config(config: CaClientConfig) -> CaResult<Self> {
        #[cfg(feature = "experimental-rust-tls")]
        if config.tls.is_some() {
            tracing::warn!(
                "═══════════════════════════════════════════════════════════════════════\n \
                CA client TLS ENABLED — non-standard, Rust-only extension.\n \
                Cannot connect to C softIoc, EDM, MEDM, CSS, or pyepics-based tools.\n \
                See doc/11-tls-design.md for rationale.\n \
                ═══════════════════════════════════════════════════════════════════════"
            );
        }
        epics_base_rs::runtime::task::spawn(async { repeater::ensure_repeater().await });
        let mut addr_list = parse_addr_list_with_hostnames()?;
        let discovery_cfg = config.discovery.clone().or_else(crate::discovery::from_env);
        let mut backends: Vec<Box<dyn crate::discovery::Backend>> = match discovery_cfg {
            Some(cfg) => crate::discovery::build_backends(cfg),
            None => Vec::new(),
        };
        backends.extend(config.extra_backends);
        if !backends.is_empty() {
            let mut discovered: Vec<SocketAddr> = Vec::new();
            for b in &backends {
                for addr in b.discover().await {
                    if !discovered.contains(&addr) {
                        discovered.push(addr);
                    }
                }
            }
            if !discovered.is_empty() {
                tracing::info!(
                    count = discovered.len(),
                    "discovered IOCs via service discovery: {:?}",
                    discovered
                );
            }
            for addr in discovered {
                if !addr_list.iter().any(|e| e.sock == addr) {
                    let port = match addr {
                        SocketAddr::V4(a) => a.port(),
                        SocketAddr::V6(a) => a.port(),
                    };
                    addr_list.push(AddrEntry::new(addr, None, port));
                }
            }
        }
        let nameserver_entries = parse_nameserver_list();
        // SNI overrides: host names from the name-server list, then the
        // explicit SNI map (which wins on conflicts).
        #[cfg(feature = "experimental-rust-tls")]
        let sni_overrides: std::collections::HashMap<SocketAddr, String> = {
            let mut map: std::collections::HashMap<SocketAddr, String> = nameserver_entries
                .iter()
                .filter_map(|(addr, host)| host.clone().map(|h| (*addr, h)))
                .collect();
            for (addr, host) in parse_tls_sni_map() {
                map.insert(addr, host);
            }
            map
        };
        let nameserver_addrs: Vec<SocketAddr> =
            nameserver_entries.iter().map(|(a, _)| *a).collect();
        let (search_tx, search_rx) = mpsc::unbounded_channel();
        let (search_resp_tx, search_resp_rx) = mpsc::unbounded_channel();
        let (transport_tx, transport_rx) = mpsc::unbounded_channel();
        let (transport_evt_tx, transport_evt_rx) = mpsc::unbounded_channel();
        let (coord_tx, coord_rx) = mpsc::unbounded_channel();
        let search_attempts: types::SearchAttempts = Arc::new(dashmap::DashMap::new());
        let search_task = epics_base_rs::runtime::task::spawn(search::run_search_engine(
            addr_list,
            nameserver_addrs,
            search_rx,
            search_resp_tx,
            search_attempts.clone(),
        ));
        #[cfg(feature = "experimental-rust-tls")]
        let tls_arc = config.tls.as_ref().and_then(|t| match t {
            crate::tls::TlsConfig::Client(arc) => Some(arc.clone()),
            crate::tls::TlsConfig::Server(_) => {
                tracing::warn!("server-side TlsConfig passed to CaClient; ignoring");
                None
            }
        });
        let in_flight = InFlightOps::new();
        let snapshots: ChannelSnapshots = Arc::new(dashmap::DashMap::new());
        let server_writers: DirectServerWriters = Arc::new(dashmap::DashMap::new());
        let last_rx_at: ServerLastRxAt = Arc::new(dashmap::DashMap::new());
        let transport_task = {
            #[cfg(feature = "experimental-rust-tls")]
            {
                epics_base_rs::runtime::task::spawn(transport::run_transport_manager(
                    transport_rx,
                    transport_evt_tx,
                    in_flight.clone(),
                    server_writers.clone(),
                    last_rx_at.clone(),
                    tls_arc,
                    config.tls_server_name.clone(),
                    sni_overrides,
                ))
            }
            #[cfg(not(feature = "experimental-rust-tls"))]
            {
                epics_base_rs::runtime::task::spawn(transport::run_transport_manager(
                    transport_rx,
                    transport_evt_tx,
                    in_flight.clone(),
                    server_writers.clone(),
                    last_rx_at.clone(),
                ))
            }
        };
        let diagnostics = Arc::new(CaDiagnostics::default());
        let exception_slot: types::CaExceptionSlot = Arc::new(parking_lot::RwLock::new(None));
        let (beacon_ctrl_tx, beacon_ctrl_rx) =
            mpsc::unbounded_channel::<beacon_monitor::BeaconControl>();
        let coordinator = epics_base_rs::runtime::task::spawn(run_coordinator(
            coord_rx,
            search_resp_rx,
            transport_evt_rx,
            search_tx.clone(),
            transport_tx.clone(),
            in_flight.clone(),
            snapshots.clone(),
            last_rx_at,
            diagnostics.clone(),
            exception_slot.clone(),
            search_attempts.clone(),
            beacon_ctrl_tx,
        ));
        let beacon_task = epics_base_rs::runtime::task::spawn(beacon_monitor::run_beacon_monitor(
            coord_tx.clone(),
            beacon_ctrl_rx,
        ));
        Ok(Self {
            search_tx,
            transport_tx,
            coord_tx,
            one_shot_channels: Mutex::new(OneShotChannelCache::default()),
            in_flight,
            snapshots,
            server_writers,
            diagnostics,
            search_attempts,
            exception_slot,
            _coordinator: coordinator,
            _search_task: search_task,
            _transport_task: transport_task,
            _beacon_task: beacon_task,
        })
    }

    /// Returns a snapshot of the client's diagnostic counters and event history.
    pub fn diagnostics(&self) -> DiagnosticsSnapshot {
        self.diagnostics.snapshot()
    }

    /// Installs `f` as the CA exception handler and returns the previous one, if any.
    pub fn set_exception_handler<F>(&self, f: F) -> Option<types::CaExceptionHandler>
    where
        F: Fn(&types::CaException) + Send + Sync + 'static,
    {
        let new = Arc::new(f);
        let mut slot = self.exception_slot.write();
        slot.replace(new)
    }

    /// Removes the CA exception handler and returns it, if one was installed.
    pub fn clear_exception_handler(&self) -> Option<types::CaExceptionHandler> {
        self.exception_slot.write().take()
    }

    /// Asks the coordinator how many IOC connections exist.
    ///
    /// Returns 0 if the coordinator is no longer running.
    pub async fn ioc_connection_count(&self) -> usize {
        let (tx, rx) = oneshot::channel();
        if self
            .coord_tx
            .send(CoordRequest::GetIocConnectionCount { reply: tx })
            .is_err()
        {
            return 0;
        }
        rx.await.unwrap_or(0)
    }

    /// Requests a coordinator shutdown and waits at most 2 seconds for it to
    /// acknowledge. Send failures and timeouts are ignored.
    pub async fn shutdown(&self) {
        let (tx, rx) = oneshot::channel();
        let _ = self.coord_tx.send(CoordRequest::Shutdown { reply: tx });
        let _ = tokio::time::timeout(Duration::from_secs(2), rx).await;
    }

    /// Creates a channel for `name` (after `expand_pv_name`) and starts searching for it.
    pub fn create_channel(&self, name: &str) -> CaChannel {
        self.create_channel_expanded(expand_pv_name(name))
    }

    /// Creates a channel for an already-expanded PV name.
    ///
    /// Registers the channel with the coordinator and schedules an initial
    /// search. Send failures (the tasks have exited) are ignored.
    fn create_channel_expanded(&self, pv_name: String) -> CaChannel {
        let cid = alloc_cid();
        let (conn_tx, _) = broadcast::channel(16);
        let _ = self.coord_tx.send(CoordRequest::RegisterChannel {
            cid,
            pv_name: pv_name.clone(),
            conn_tx: conn_tx.clone(),
        });
        let _ = self.search_tx.send(SearchRequest::Schedule {
            cid,
            pv_name: pv_name.clone(),
            reason: SearchReason::Initial,
        });
        // Shared by every clone of the channel; sends DropChannel when the last clone drops.
        let lifecycle = Arc::new(ChannelLifecycle {
            cid,
            coord_tx: self.coord_tx.clone(),
        });
        let channel_pv_name: Arc<str> = Arc::from(pv_name.as_str());
        CaChannel {
            cid,
            pv_name: channel_pv_name,
            coord_tx: self.coord_tx.clone(),
            transport_tx: self.transport_tx.clone(),
            in_flight: self.in_flight.clone(),
            snapshots: self.snapshots.clone(),
            server_writers: self.server_writers.clone(),
            conn_tx,
            cached_read: Arc::new(Mutex::new(None)),
            search_attempts: self.search_attempts.clone(),
            _lifecycle: lifecycle,
        }
    }

    /// Returns the cached one-shot channel for `name`, creating it if needed.
    fn cached_one_shot_channel(&self, name: &str) -> CaChannel {
        let pv_name = expand_pv_name(name);
        self.one_shot_channels.lock().get_or_create(self, pv_name)
    }

    /// Adds `addr` to the search address list.
    pub fn add_address(&self, addr: SocketAddr) {
        let _ = self.search_tx.send(SearchRequest::AddAddress(addr));
    }

    /// Replaces the search address list with `list`.
    pub fn set_address_list(&self, list: Vec<SocketAddr>) {
        let _ = self.search_tx.send(SearchRequest::SetAddressList(list));
    }

    /// One-shot read.
    ///
    /// Creates a channel for `pv_name`, waits up to 3 s for it to connect,
    /// reads its native value and then releases the channel.
    ///
    /// # Errors
    /// `CaError::ChannelNotFound` if the channel does not connect within 3 s.
    /// Otherwise, any error from [`CaChannel::get`].
    pub async fn caget(&self, pv_name: &str) -> CaResult<(DbFieldType, EpicsValue)> {
        let ch = self.create_channel(pv_name);
        ch.wait_connected(Duration::from_secs(3)).await?;
        let result = ch.get().await;
        let _ = self
            .coord_tx
            .send(CoordRequest::DropChannel { cid: ch.cid });
        result
    }

    /// [`CaClient::caget_many_with_timeout`] with a 30 s timeout.
    pub async fn caget_many<S>(&self, pv_names: &[S]) -> Vec<CaResult<(DbFieldType, EpicsValue)>>
    where
        S: AsRef<str> + Sync,
    {
        self.caget_many_with_timeout(pv_names, Duration::from_secs(30))
            .await
    }

    /// Reads several PVs using cached one-shot channels; results follow the
    /// order of `pv_names`.
    ///
    /// Some PVs fail the first read with `Disconnected` or `ChannelNotFound`,
    /// for example because they have not connected yet. Those are waited on
    /// (up to `timeout`) and read a second time. Each phase is bounded by
    /// `timeout`, so the whole call can take longer than `timeout`.
    pub async fn caget_many_with_timeout<S>(
        &self,
        pv_names: &[S],
        timeout: Duration,
    ) -> Vec<CaResult<(DbFieldType, EpicsValue)>>
    where
        S: AsRef<str> + Sync,
    {
        let channels: Vec<CaChannel> = pv_names
            .iter()
            .map(|name| self.cached_one_shot_channel(name.as_ref()))
            .collect();
        let mut results = self.get_many_with_timeout(&channels, timeout).await;
        let retry_indices: Vec<usize> = results
            .iter()
            .enumerate()
            .filter_map(|(idx, result)| {
                if matches!(
                    result,
                    Err(CaError::Disconnected) | Err(CaError::ChannelNotFound(_))
                ) {
                    Some(idx)
                } else {
                    None
                }
            })
            .collect();
        if retry_indices.is_empty() {
            return results;
        }
        let connected = futures_util::future::join_all(retry_indices.iter().map(|&idx| {
            let channel = channels[idx].clone();
            async move { (idx, channel.wait_connected(timeout).await) }
        }))
        .await;
        let mut ready_indices = Vec::new();
        let mut ready_channels = Vec::new();
        for (idx, result) in connected {
            match result {
                Ok(()) => {
                    ready_indices.push(idx);
                    ready_channels.push(channels[idx].clone());
                }
                Err(e) => results[idx] = Err(e),
            }
        }
        let read_results = self.get_many_with_timeout(&ready_channels, timeout).await;
        for (idx, result) in ready_indices.into_iter().zip(read_results) {
            results[idx] = result;
        }
        results
    }

    /// [`CaChannel::get_many_with_timeout`] with a 30 s timeout.
    pub async fn get_many(
        &self,
        channels: &[CaChannel],
    ) -> Vec<CaResult<(DbFieldType, EpicsValue)>> {
        self.get_many_with_timeout(channels, Duration::from_secs(30))
            .await
    }

    /// Delegates to [`CaChannel::get_many_with_timeout`].
    pub async fn get_many_with_timeout(
        &self,
        channels: &[CaChannel],
        timeout: Duration,
    ) -> Vec<CaResult<(DbFieldType, EpicsValue)>> {
        CaChannel::get_many_with_timeout(channels, timeout).await
    }

    /// One-shot write without completion notification.
    ///
    /// Waits up to 3 s for `pv_name` to connect, parses `value_str` as the
    /// channel's native type and sends a plain write.
    pub async fn caput(&self, pv_name: &str, value_str: &str) -> CaResult<()> {
        let ch = self.create_channel(pv_name);
        ch.wait_connected(Duration::from_secs(3)).await?;
        let snap = ch.snapshot()?;
        let value = EpicsValue::parse(snap.native_type, value_str)?;
        ch.put_nowait(&value).await?;
        let _ = self
            .coord_tx
            .send(CoordRequest::DropChannel { cid: ch.cid });
        Ok(())
    }

    /// One-shot write with completion notification.
    ///
    /// `timeout_secs` bounds both the connection wait and the wait for the
    /// server's completion notice. A value too large to represent as a
    /// `Duration` (including `f64::INFINITY`) means "no practical bound".
    ///
    /// # Errors
    /// * `CaError::Timeout` immediately, without creating a channel, if
    ///   `timeout_secs` is negative or NaN.
    /// * `CaError::ChannelNotFound` if the channel does not connect in time.
    /// * Otherwise, any error from parsing or from [`CaChannel::put_with_timeout`].
    pub async fn caput_callback(
        &self,
        pv_name: &str,
        value_str: &str,
        timeout_secs: f64,
    ) -> CaResult<()> {
        // `Duration::from_secs_f64` panics on negative, NaN, infinite or
        // overflowing input; this value comes straight from the caller.
        let timeout = match Duration::try_from_secs_f64(timeout_secs) {
            Ok(timeout) => timeout,
            // Positive but unrepresentable (e.g. +inf): wait without a practical bound.
            Err(_) if timeout_secs > 0.0 => Duration::MAX,
            // Negative or NaN: the deadline has already passed.
            Err(_) => return Err(CaError::Timeout),
        };
        let ch = self.create_channel(pv_name);
        ch.wait_connected(timeout).await?;
        let snap = ch.snapshot()?;
        let value = EpicsValue::parse(snap.native_type, value_str)?;
        ch.put_with_timeout(&value, timeout).await?;
        let _ = self
            .coord_tx
            .send(CoordRequest::DropChannel { cid: ch.cid });
        Ok(())
    }

    /// One-shot channel info.
    ///
    /// Waits up to 3 s for `pv_name` to connect and returns its server
    /// address, native type, element count and access rights.
    pub async fn cainfo(&self, pv_name: &str) -> CaResult<ChannelInfo> {
        let ch = self.create_channel(pv_name);
        ch.wait_connected(Duration::from_secs(3)).await?;
        let info = ch.info().await;
        let _ = self
            .coord_tx
            .send(CoordRequest::DropChannel { cid: ch.cid });
        info
    }

    /// Subscribes to `pv_name` and calls `callback` with each monitored value.
    ///
    /// Returns `Ok(())` when the monitor stream ends, or the first error it delivers.
    pub async fn camonitor<F>(&self, pv_name: &str, mut callback: F) -> CaResult<()>
    where
        F: FnMut(EpicsValue),
    {
        let ch = self.create_channel(pv_name);
        let mut monitor = ch.subscribe().await?;
        while let Some(result) = monitor.recv().await {
            match result {
                Ok(snap) => callback(snap.value),
                Err(e) => return Err(e),
            }
        }
        Ok(())
    }
}
impl Drop for CaClient {
    /// Stops the background tasks.
    ///
    /// Inside a Tokio runtime, a detached task first sends `Shutdown` to the
    /// coordinator and waits up to 2 s for the acknowledgement, then aborts the
    /// coordinator, transport, search and beacon tasks. Without a runtime the
    /// tasks are aborted immediately.
    fn drop(&mut self) {
        if tokio::runtime::Handle::try_current().is_err() {
            self._coordinator.abort();
            self._transport_task.abort();
            self._search_task.abort();
            self._beacon_task.abort();
            return;
        }
        // Abort order: coordinator, transport, search, beacon.
        let handles = [
            self._coordinator.abort_handle(),
            self._transport_task.abort_handle(),
            self._search_task.abort_handle(),
            self._beacon_task.abort_handle(),
        ];
        let coord_tx = self.coord_tx.clone();
        tokio::spawn(async move {
            let (ack_tx, ack_rx) = oneshot::channel();
            if coord_tx.send(CoordRequest::Shutdown { reply: ack_tx }).is_ok() {
                let _ = tokio::time::timeout(Duration::from_secs(2), ack_rx).await;
            }
            for handle in handles {
                handle.abort();
            }
        });
    }
}
/// Lifetime guard shared, via `Arc`, by every clone of a `CaChannel`; when the
/// last clone is dropped, its `Drop` sends `CoordRequest::DropChannel` for `cid`.
struct ChannelLifecycle {
cid: u32,
coord_tx: mpsc::UnboundedSender<CoordRequest>,
}
impl Drop for ChannelLifecycle {
    /// Tells the coordinator to release this channel.
    fn drop(&mut self) {
        let request = CoordRequest::DropChannel { cid: self.cid };
        // The coordinator may already have exited (e.g. during client
        // shutdown); a failed send is deliberately ignored.
        let _ = self.coord_tx.send(request);
    }
}
/// A reusable ("warm") read-notify registration cached per channel by
/// `CaChannel::get_many_with_timeout`.
///
/// It is reused only while `server_addr`, `sid`, `data_type` and
/// `element_count` still match the channel's snapshot; otherwise its `ioid` is
/// deregistered.
pub(crate) struct CachedRead {
/// Request id registered in `in_flight.reads` as a `ReadWaiter::Warm`.
pub(crate) ioid: u32,
pub(crate) sid: u32,
pub(crate) server_addr: SocketAddr,
pub(crate) data_type: u16,
pub(crate) element_count: u32,
/// Slot into which each read installs its reply sender before sending the request.
pub(crate) slot: types::WarmReplySlot,
}
/// Handle to one Channel Access channel created by `CaClient`.
///
/// Clones share the same state (snapshot maps, warm read cache and lifecycle
/// guard); the channel is released when the last clone drops its
/// `_lifecycle` Arc.
#[derive(Clone)]
pub struct CaChannel {
/// Client-side channel id from `alloc_cid`.
cid: u32,
/// Expanded PV name.
pv_name: Arc<str>,
coord_tx: mpsc::UnboundedSender<CoordRequest>,
transport_tx: mpsc::UnboundedSender<TransportCommand>,
/// Outstanding read/write waiters keyed by ioid.
in_flight: InFlightOps,
/// Per-cid snapshots; a channel is usable while its snapshot is operational.
snapshots: ChannelSnapshots,
/// Per-server direct writers for the fast send path.
server_writers: DirectServerWriters,
/// Broadcaster returned to callers via `connection_events`.
conn_tx: broadcast::Sender<ConnectionEvent>,
/// Warm read registration reused by `get_many_with_timeout`.
cached_read: Arc<Mutex<Option<CachedRead>>>,
search_attempts: types::SearchAttempts,
// Kept last so it drops after the other fields.
_lifecycle: Arc<ChannelLifecycle>,
}
impl CaChannel {
/// Waits until the coordinator reports this channel as connected.
///
/// # Errors
/// * `CaError::ChannelNotFound` (with the PV name) if `timeout` elapses first.
/// * `CaError::Shutdown` if the coordinator drops the request, for example
///   because it has exited.
pub async fn wait_connected(&self, timeout: Duration) -> CaResult<()> {
    let (done_tx, done_rx) = oneshot::channel();
    // If the send fails, `done_tx` is dropped with the rejected request and the
    // receive below resolves to `CaError::Shutdown`.
    let _ = self.coord_tx.send(CoordRequest::WaitConnected {
        cid: self.cid,
        reply: done_tx,
    });
    match tokio::time::timeout(timeout, done_rx).await {
        Err(_elapsed) => Err(CaError::ChannelNotFound(self.pv_name.to_string())),
        Ok(Err(_closed)) => Err(CaError::Shutdown),
        Ok(Ok(())) => Ok(()),
    }
}
/// Returns the channel's PV name, server address, native type, element
/// count and access rights.
///
/// # Errors
/// `CaError::Disconnected` if the channel is not currently operational.
pub async fn info(&self) -> CaResult<ChannelInfo> {
    self.snapshot().map(|snap| ChannelInfo {
        pv_name: self.pv_name.to_string(),
        server_addr: snap.server_addr,
        native_type: snap.native_type,
        element_count: snap.element_count,
        access_rights: snap.access_rights,
    })
}
/// Number of search attempts recorded for this channel, or 0 if none were recorded.
pub fn search_attempts(&self) -> u32 {
    match self.search_attempts.get(&self.cid) {
        Some(counter) => counter.load(std::sync::atomic::Ordering::Relaxed),
        None => 0,
    }
}
/// Returns a copy of the channel's current snapshot.
///
/// # Errors
/// `CaError::Disconnected` if there is no snapshot or its state is not operational.
fn snapshot(&self) -> CaResult<ChannelSnapshotPublic> {
    let entry = self.snapshots.get(&self.cid).ok_or(CaError::Disconnected)?;
    if entry.state.is_operational() {
        Ok(entry.clone())
    } else {
        Err(CaError::Disconnected)
    }
}
fn direct_writer(&self, server_addr: SocketAddr) -> Option<DirectServerWriter> {
self.server_writers.get(&server_addr).map(|w| w.clone())
}
/// Serializes a `CA_PROTO_READ_NOTIFY` request for server channel `sid`.
///
/// Counts that fit in 16 bits go in the header's `count` field. Larger
/// counts are passed through `set_payload_size(0, count)`.
fn build_read_notify_frame(sid: u32, data_type: u16, count: u32, ioid: u32) -> Vec<u8> {
    let mut header = CaHeader::new(CA_PROTO_READ_NOTIFY);
    header.data_type = data_type;
    header.cid = sid;
    header.available = ioid;
    match u16::try_from(count) {
        Ok(short_count) => header.count = short_count,
        Err(_) => {
            header.set_payload_size(0, count);
        }
    }
    header.to_bytes_extended()
}
/// Converts a read reply into `(type, value)`.
///
/// A `Plain` reply is returned as-is. A `Raw` reply is decoded from its
/// bytes using its wire data type and element count.
fn decode_plain_read_reply(reply: ReadReply) -> CaResult<(DbFieldType, EpicsValue)> {
    let (data_type, count, data) = match reply {
        ReadReply::Plain { dbr_type, value } => return Ok((dbr_type, value)),
        ReadReply::Raw {
            data_type,
            count,
            data,
        } => (data_type, count, data),
    };
    let dbr_type = DbFieldType::from_u16(data_type)?;
    let value = EpicsValue::from_bytes_array(dbr_type, &data, count as usize)?;
    Ok((dbr_type, value))
}
/// Serializes a write request (`cmd`) for server channel `sid`.
///
/// The payload is zero-padded to an 8-byte boundary (`align8`). `ioid` is
/// written to the header's `available` field when present (write-notify).
fn build_write_frame(
    cmd: u16,
    sid: u32,
    data_type: u16,
    count: u32,
    ioid: Option<u32>,
    payload: Vec<u8>,
) -> Vec<u8> {
    let mut body = payload;
    let aligned_len = align8(body.len());
    body.resize(aligned_len, 0);

    let mut header = CaHeader::new(cmd);
    header.data_type = data_type;
    header.cid = sid;
    if let Some(request_id) = ioid {
        header.available = request_id;
    }
    header.set_payload_size(body.len(), count);

    let mut frame = header.to_bytes_extended();
    frame.extend_from_slice(&body);
    frame
}
/// Sends a read-notify request, preferring the server's direct writer and
/// otherwise queueing a `TransportCommand::ReadNotify`.
///
/// # Errors
/// The direct writer's send error, or `CaError::Shutdown` if the transport
/// queue is closed.
fn send_read_notify_fast(
    &self,
    snap: &ChannelSnapshotPublic,
    data_type: u16,
    count: u32,
    ioid: u32,
) -> CaResult<()> {
    match self.direct_writer(snap.server_addr) {
        Some(writer) => {
            let frame = Self::build_read_notify_frame(snap.sid, data_type, count, ioid);
            writer.send_frame(frame)
        }
        None => {
            let command = TransportCommand::ReadNotify {
                sid: snap.sid,
                data_type,
                count,
                ioid,
                server_addr: snap.server_addr,
            };
            self.transport_tx
                .send(command)
                .map_err(|_| CaError::Shutdown)
        }
    }
}
/// Sends a write-notify (`CA_PROTO_WRITE_NOTIFY`) request in the channel's
/// native type, preferring the direct writer and otherwise queueing a
/// `TransportCommand::WriteNotify`.
///
/// # Errors
/// The direct writer's send error, or `CaError::Shutdown` if the transport
/// queue is closed.
fn send_write_notify_fast(
    &self,
    snap: &ChannelSnapshotPublic,
    count: u32,
    ioid: u32,
    payload: Vec<u8>,
) -> CaResult<()> {
    let data_type = snap.native_type as u16;
    match self.direct_writer(snap.server_addr) {
        Some(writer) => writer.send_frame(Self::build_write_frame(
            CA_PROTO_WRITE_NOTIFY,
            snap.sid,
            data_type,
            count,
            Some(ioid),
            payload,
        )),
        None => {
            let command = TransportCommand::WriteNotify {
                sid: snap.sid,
                data_type,
                count,
                ioid,
                payload,
                server_addr: snap.server_addr,
            };
            self.transport_tx
                .send(command)
                .map_err(|_| CaError::Shutdown)
        }
    }
}
/// Sends a plain write (`CA_PROTO_WRITE`, no completion notice) in the
/// channel's native type, preferring the direct writer and otherwise
/// queueing a `TransportCommand::Write`.
///
/// # Errors
/// The direct writer's send error, or `CaError::Shutdown` if the transport
/// queue is closed.
fn send_write_nowait_fast(
    &self,
    snap: &ChannelSnapshotPublic,
    count: u32,
    payload: Vec<u8>,
) -> CaResult<()> {
    let data_type = snap.native_type as u16;
    match self.direct_writer(snap.server_addr) {
        Some(writer) => writer.send_frame(Self::build_write_frame(
            CA_PROTO_WRITE,
            snap.sid,
            data_type,
            count,
            None,
            payload,
        )),
        None => {
            let command = TransportCommand::Write {
                sid: snap.sid,
                data_type,
                count,
                payload,
                server_addr: snap.server_addr,
            };
            self.transport_tx
                .send(command)
                .map_err(|_| CaError::Shutdown)
        }
    }
}
/// Reads the channel's native value with a 30-second timeout
/// (see [`CaChannel::get_with_timeout`]).
pub async fn get(&self) -> CaResult<(DbFieldType, EpicsValue)> {
    const DEFAULT_GET_TIMEOUT: Duration = Duration::from_secs(30);
    self.get_with_timeout(DEFAULT_GET_TIMEOUT).await
}
pub async fn get_many(channels: &[CaChannel]) -> Vec<CaResult<(DbFieldType, EpicsValue)>> {
Self::get_many_with_timeout(channels, Duration::from_secs(30)).await
}
pub async fn get_many_with_timeout(
channels: &[CaChannel],
timeout: Duration,
) -> Vec<CaResult<(DbFieldType, EpicsValue)>> {
enum PendingKind {
Cold {
ioid: u32,
in_flight: InFlightOps,
cid: u32,
cached_read_slot: Arc<Mutex<Option<CachedRead>>>,
sid: u32,
server_addr: SocketAddr,
data_type: u16,
element_count: u32,
},
Warm {
ioid: u32,
in_flight: InFlightOps,
cached_read_slot: Arc<Mutex<Option<CachedRead>>>,
cached: CachedRead,
},
}
struct Pending {
index: usize,
reply_rx: oneshot::Receiver<CaResult<ReadReply>>,
kind: PendingKind,
}
struct BulkReadGroup {
writer: DirectServerWriter,
frame: Vec<u8>,
pending: Vec<Pending>,
}
let mut results: Vec<Option<CaResult<(DbFieldType, EpicsValue)>>> =
(0..channels.len()).map(|_| None).collect();
let mut groups: HashMap<SocketAddr, BulkReadGroup> = HashMap::new();
let mut pending: Vec<Pending> = Vec::new();
for (index, ch) in channels.iter().enumerate() {
let snap = match ch.snapshot() {
Ok(s) => s,
Err(e) => {
results[index] = Some(Err(e));
continue;
}
};
let warm_taken: Option<CachedRead> = {
let mut guard = ch.cached_read.lock();
let matches = matches!(guard.as_ref(), Some(c)
if c.server_addr == snap.server_addr
&& c.sid == snap.sid
&& c.data_type == snap.native_type as u16
&& c.element_count == snap.element_count);
if matches {
guard.take()
} else if guard.is_some() {
let stale = guard.take().unwrap();
ch.in_flight.reads.remove(&stale.ioid);
None
} else {
None
}
};
let (reply_tx, reply_rx) = oneshot::channel();
let (frame, kind) = if let Some(cached) = warm_taken {
*cached.slot.lock() = Some(reply_tx);
let frame = Self::build_read_notify_frame(
cached.sid,
cached.data_type,
cached.element_count,
cached.ioid,
);
let kind = PendingKind::Warm {
ioid: cached.ioid,
in_flight: ch.in_flight.clone(),
cached_read_slot: ch.cached_read.clone(),
cached,
};
(frame, kind)
} else {
let ioid = alloc_ioid();
ch.in_flight.reads.insert(
ioid,
ReadWaiter::OneShot {
cid: ch.cid,
mode: ReadReplyMode::Plain,
reply_tx,
},
);
let frame = Self::build_read_notify_frame(
snap.sid,
snap.native_type as u16,
snap.element_count,
ioid,
);
let kind = PendingKind::Cold {
ioid,
in_flight: ch.in_flight.clone(),
cid: ch.cid,
cached_read_slot: ch.cached_read.clone(),
sid: snap.sid,
server_addr: snap.server_addr,
data_type: snap.native_type as u16,
element_count: snap.element_count,
};
(frame, kind)
};
let pending_read = Pending {
index,
reply_rx,
kind,
};
if let Some(writer) = ch.direct_writer(snap.server_addr) {
let group = groups
.entry(snap.server_addr)
.or_insert_with(|| BulkReadGroup {
writer,
frame: Vec::new(),
pending: Vec::new(),
});
group.frame.extend_from_slice(&frame);
group.pending.push(pending_read);
} else {
match pending_read.kind {
PendingKind::Cold {
ioid, in_flight, ..
} => match ch.send_read_notify_fast(
&snap,
snap.native_type as u16,
snap.element_count,
ioid,
) {
Ok(()) => pending.push(Pending {
index,
reply_rx: pending_read.reply_rx,
kind: PendingKind::Cold {
ioid,
in_flight,
cid: ch.cid,
cached_read_slot: ch.cached_read.clone(),
sid: snap.sid,
server_addr: snap.server_addr,
data_type: snap.native_type as u16,
element_count: snap.element_count,
},
}),
Err(e) => {
in_flight.reads.remove(&ioid);
results[index] = Some(Err(e));
}
},
PendingKind::Warm {
ioid,
in_flight,
cached_read_slot,
..
} => {
in_flight.reads.remove(&ioid);
*cached_read_slot.lock() = None;
results[index] = Some(Err(CaError::Disconnected));
}
}
}
}
for (_, group) in groups {
match group.writer.send_frame(group.frame) {
Ok(()) => pending.extend(group.pending),
Err(_) => {
for p in group.pending {
match p.kind {
PendingKind::Cold {
ioid, in_flight, ..
} => {
in_flight.reads.remove(&ioid);
}
PendingKind::Warm {
ioid,
in_flight,
cached_read_slot,
..
} => {
in_flight.reads.remove(&ioid);
*cached_read_slot.lock() = None;
}
}
results[p.index] = Some(Err(CaError::Disconnected));
}
}
}
}
let deadline = tokio::time::Instant::now() + timeout;
for p in pending {
let Pending {
index,
reply_rx,
kind,
} = p;
let result = tokio::time::timeout_at(deadline, reply_rx).await;
let decoded: CaResult<(DbFieldType, EpicsValue)> = match result {
Ok(Ok(Ok(reply))) => Self::decode_plain_read_reply(reply),
Ok(Ok(Err(e))) => Err(e),
Ok(Err(_)) => Err(CaError::Shutdown),
Err(_) => Err(CaError::Timeout),
};
let is_local_error = matches!(decoded, Err(CaError::Timeout) | Err(CaError::Shutdown));
match kind {
PendingKind::Cold {
ioid,
in_flight,
cid,
cached_read_slot,
sid,
server_addr,
data_type,
element_count,
} => {
if is_local_error {
in_flight.reads.remove(&ioid);
}
if decoded.is_ok() {
let warm_ioid = alloc_ioid();
let slot: types::WarmReplySlot = Arc::new(parking_lot::Mutex::new(None));
in_flight.reads.insert(
warm_ioid,
ReadWaiter::Warm {
cid,
mode: ReadReplyMode::Plain,
slot: slot.clone(),
},
);
let cached = CachedRead {
ioid: warm_ioid,
sid,
server_addr,
data_type,
element_count,
slot,
};
let mut guard = cached_read_slot.lock();
if guard.is_none() {
*guard = Some(cached);
} else {
drop(guard);
in_flight.reads.remove(&warm_ioid);
}
}
}
PendingKind::Warm {
ioid,
in_flight,
cached_read_slot,
cached,
} => {
if is_local_error {
in_flight.reads.remove(&ioid);
drop(cached);
*cached_read_slot.lock() = None;
} else {
let mut guard = cached_read_slot.lock();
if guard.is_none() {
*guard = Some(cached);
} else {
drop(guard);
in_flight.reads.remove(&ioid);
}
}
}
}
results[index] = Some(decoded);
}
results
.into_iter()
.map(|r| r.unwrap_or(Err(CaError::Shutdown)))
.collect()
}
/// Reads the channel's native value, waiting at most `timeout` for the reply.
///
/// # Errors
/// * `CaError::Disconnected` if the channel is not operational.
/// * The send error if the request cannot be sent.
/// * `CaError::Timeout` on timeout.
/// * `CaError::Shutdown` if the reply channel closes.
/// * The error carried by the reply, or a decode error.
pub async fn get_with_timeout(&self, timeout: Duration) -> CaResult<(DbFieldType, EpicsValue)> {
    let snap = self.snapshot()?;
    let ioid = alloc_ioid();
    let (reply_tx, reply_rx) = oneshot::channel();
    let waiter = ReadWaiter::OneShot {
        cid: self.cid,
        mode: ReadReplyMode::Plain,
        reply_tx,
    };
    self.in_flight.reads.insert(ioid, waiter);

    let sent = self.send_read_notify_fast(&snap, snap.native_type as u16, snap.element_count, ioid);
    if let Err(e) = sent {
        self.in_flight.reads.remove(&ioid);
        return Err(e);
    }

    let outcome = tokio::time::timeout(timeout, reply_rx).await;
    // Deregister unconditionally so a timed-out waiter does not linger.
    self.in_flight.reads.remove(&ioid);
    let reply = match outcome {
        Err(_) => return Err(CaError::Timeout),
        Ok(Err(_)) => return Err(CaError::Shutdown),
        Ok(Ok(reply)) => reply?,
    };
    Self::decode_plain_read_reply(reply)
}
/// Reads the full element count with metadata of the given `class`
/// (see [`CaChannel::get_with_metadata_count`]).
pub async fn get_with_metadata(&self, class: DbrClass) -> CaResult<Snapshot> {
    // A count of 0 requests the channel's full native element count.
    let full_count = 0;
    self.get_with_metadata_count(class, full_count).await
}
/// Reads up to `count` elements (0 = all) as the DBR type of `class`
/// derived from the channel's native type, and decodes the reply into a
/// `Snapshot`. The reply wait is bounded by 30 seconds.
///
/// The requested count is capped at the channel's element count. STS and GR
/// types are computed as native + 7 and native + 21. TIME and CTRL come from
/// `time_dbr_type` / `ctrl_dbr_type`.
///
/// # Errors
/// * `CaError::Disconnected` if not operational.
/// * A type-conversion or send error.
/// * `CaError::Timeout` / `CaError::Shutdown`.
/// * The reply's error.
/// * `CaError::Protocol` if a plain (non-raw) reply arrives.
pub async fn get_with_metadata_count(&self, class: DbrClass, count: u32) -> CaResult<Snapshot> {
    let snap = self.snapshot()?;
    let request_count = match count {
        0 => snap.element_count,
        n => n.min(snap.element_count),
    };
    let native = DbFieldType::from_u16(snap.native_type as u16)?;
    let request_type = match class {
        DbrClass::Plain => native as u16,
        DbrClass::Sts => native as u16 + 7,
        DbrClass::Time => native.time_dbr_type(),
        DbrClass::Gr => native as u16 + 21,
        DbrClass::Ctrl => native.ctrl_dbr_type(),
    };

    let ioid = alloc_ioid();
    let (reply_tx, reply_rx) = oneshot::channel();
    let waiter = ReadWaiter::OneShot {
        cid: self.cid,
        mode: ReadReplyMode::Raw,
        reply_tx,
    };
    self.in_flight.reads.insert(ioid, waiter);
    if let Err(e) = self.send_read_notify_fast(&snap, request_type, request_count, ioid) {
        self.in_flight.reads.remove(&ioid);
        return Err(e);
    }

    let outcome = tokio::time::timeout(Duration::from_secs(30), reply_rx).await;
    self.in_flight.reads.remove(&ioid);
    let reply = match outcome {
        Err(_) => return Err(CaError::Timeout),
        Ok(Err(_)) => return Err(CaError::Shutdown),
        Ok(Ok(reply)) => reply?,
    };
    let ReadReply::Raw {
        data_type,
        count,
        data,
    } = reply
    else {
        return Err(CaError::Protocol(
            "metadata read returned a plain scalar reply".into(),
        ));
    };
    decode_dbr(data_type, &data, count as usize)
}
/// Writes `value` with a put-callback (`CA_PROTO_WRITE_NOTIFY`) and waits for
/// the completion notice.
///
/// The wait is bounded by `EPICS_CA_PUT_TIMEOUT` (seconds, fractional
/// allowed, surrounding whitespace ignored) when it holds a representable,
/// non-negative duration. Otherwise the bound is 30 seconds. Invalid values
/// such as `-1`, `nan`, `inf` or `1e30` previously made
/// `Duration::from_secs_f64` panic; they now fall back to the default.
///
/// # Errors
/// * `CaError::Disconnected` if the channel is not operational.
/// * The send error if the request cannot be sent.
/// * `CaError::Timeout` if no completion arrives in time.
/// * `CaError::Shutdown` if the reply channel closes.
/// * The error reported for the write.
pub async fn put(&self, value: &EpicsValue) -> CaResult<()> {
    const DEFAULT_PUT_TIMEOUT: Duration = Duration::from_secs(30);
    let snap = self.snapshot()?;
    let ioid = alloc_ioid();
    let (reply_tx, reply_rx) = oneshot::channel();
    self.in_flight.writes.insert(ioid, (self.cid, reply_tx));
    let payload = value.to_bytes();
    let count = value.count() as u32;
    if let Err(e) = self.send_write_notify_fast(&snap, count, ioid, payload) {
        self.in_flight.writes.remove(&ioid);
        return Err(e);
    }
    // The environment is user-controlled: reject anything that is not a valid
    // `Duration` instead of panicking inside `Duration::from_secs_f64`.
    let timeout = epics_base_rs::runtime::env::get("EPICS_CA_PUT_TIMEOUT")
        .and_then(|s| s.trim().parse::<f64>().ok())
        .and_then(|secs| Duration::try_from_secs_f64(secs).ok())
        .unwrap_or(DEFAULT_PUT_TIMEOUT);
    let result = tokio::time::timeout(timeout, reply_rx).await;
    self.in_flight.writes.remove(&ioid);
    result
        .map_err(|_| CaError::Timeout)?
        .map_err(|_| CaError::Shutdown)?
}
/// Writes `value` to the channel and waits up to `timeout` for the server's
/// write confirmation.
///
/// # Errors
/// Propagates snapshot and send failures, plus any error status the server
/// reports for the write. Returns `CaError::Timeout` when no confirmation
/// arrives in time, and `CaError::Shutdown` when the confirmation sender is
/// dropped.
pub async fn put_with_timeout(&self, value: &EpicsValue, timeout: Duration) -> CaResult<()> {
    let snap = self.snapshot()?;
    let ioid = alloc_ioid();
    let (reply_tx, reply_rx) = oneshot::channel();
    self.in_flight.writes.insert(ioid, (self.cid, reply_tx));

    let payload = value.to_bytes();
    let element_count = value.count() as u32;
    if let Err(err) = self.send_write_notify_fast(&snap, element_count, ioid, payload) {
        // The request never left; drop the waiter so it does not linger.
        self.in_flight.writes.remove(&ioid);
        return Err(err);
    }

    let waited = tokio::time::timeout(timeout, reply_rx).await;
    self.in_flight.writes.remove(&ioid);
    match waited {
        Err(_elapsed) => Err(CaError::Timeout),
        Ok(Err(_closed)) => Err(CaError::Shutdown),
        Ok(Ok(status)) => status,
    }
}
/// Writes `value` to the channel without asking for, or waiting on, a
/// confirmation from the server.
///
/// # Errors
/// Propagates snapshot lookup and send failures.
pub async fn put_nowait(&self, value: &EpicsValue) -> CaResult<()> {
    let snap = self.snapshot()?;
    let (payload, element_count) = (value.to_bytes(), value.count() as u32);
    self.send_write_nowait_fast(&snap, element_count, payload)
}
/// Creates a value/log/alarm monitor on this channel with no deadband.
///
/// Shorthand for [`subscribe_with_deadband`](Self::subscribe_with_deadband)
/// with a deadband of `0.0`.
pub async fn subscribe(&self) -> CaResult<MonitorHandle> {
    const NO_DEADBAND: f64 = 0.0;
    self.subscribe_with_deadband(NO_DEADBAND).await
}
/// Creates a monitor for `DBE_VALUE | DBE_LOG | DBE_ALARM` events with the
/// given `deadband`.
///
/// Updates are buffered in a per-monitor queue sized by
/// `EPICS_CA_MONITOR_QUEUE` (default 256, never fewer than 8).
///
/// # Errors
/// Returns `CaError::Shutdown` when the coordinator is gone, or whatever error
/// the coordinator reports for the request (e.g. an unknown channel).
pub async fn subscribe_with_deadband(&self, deadband: f64) -> CaResult<MonitorHandle> {
    const DEFAULT_QUEUE: usize = 256;
    const MIN_QUEUE: usize = 8;
    let subid = alloc_subid();
    let requested = epics_base_rs::runtime::env::get("EPICS_CA_MONITOR_QUEUE")
        .and_then(|raw| raw.parse::<usize>().ok());
    let queue_size = std::cmp::max(requested.unwrap_or(DEFAULT_QUEUE), MIN_QUEUE);

    let (callback_tx, callback_rx) = mpsc::channel(queue_size);
    let (reply_tx, reply_rx) = oneshot::channel();
    let request = CoordRequest::Subscribe {
        cid: self.cid,
        subid,
        mask: DBE_VALUE | DBE_LOG | DBE_ALARM,
        deadband,
        callback_tx,
        reply: reply_tx,
    };
    // A failed send drops `reply_tx`; the await below then reports Shutdown.
    let _ = self.coord_tx.send(request);
    let coordinator_verdict = reply_rx.await.map_err(|_| CaError::Shutdown)?;
    coordinator_verdict?;

    Ok(MonitorHandle {
        subid,
        callback_rx,
        coord_tx: self.coord_tx.clone(),
    })
}
/// Returns a new receiver for this channel's connection events.
///
/// Events include connect, disconnect, unresponsive and access-rights
/// changes. Only events broadcast after this call are observed.
pub fn connection_events(&self) -> broadcast::Receiver<ConnectionEvent> {
    self.conn_tx.subscribe()
}
pub fn on_access_rights_change<F>(&self, mut cb: F) -> tokio::task::JoinHandle<()>
where
F: FnMut(AccessRights) + Send + 'static,
{
let mut rx = self.conn_tx.subscribe();
epics_base_rs::runtime::task::spawn(async move {
while let Ok(evt) = rx.recv().await {
if let ConnectionEvent::AccessRightsChanged { read, write } = evt {
cb(AccessRights { read, write });
}
}
})
}
pub fn on_connection_change<F>(&self, mut cb: F) -> tokio::task::JoinHandle<()>
where
F: FnMut(bool) + Send + 'static,
{
let mut rx = self.conn_tx.subscribe();
epics_base_rs::runtime::task::spawn(async move {
while let Ok(evt) = rx.recv().await {
match evt {
ConnectionEvent::Connected => cb(true),
ConnectionEvent::Disconnected => cb(false),
_ => {}
}
}
})
}
/// Returns the server address of this channel, formatted as text.
///
/// The value is `info().server_addr` rendered with `to_string()`. That is
/// presumably an `ip:port` socket address rather than a DNS name
/// (NOTE(review): confirm against `ChannelInfo`).
pub async fn host_name(&self) -> CaResult<String> {
    let server = self.info().await?.server_addr;
    Ok(server.to_string())
}
/// Returns the CA minor protocol version reported by this channel's server.
///
/// Returns `None` when the channel is not operational, when no version has
/// been recorded for the server, or when the coordinator is unavailable.
pub async fn host_minor_protocol(&self) -> Option<u16> {
    let (reply, answer) = oneshot::channel();
    let request = CoordRequest::GetHostMinorProtocol {
        cid: self.cid,
        reply,
    };
    // If the send fails the sender is dropped, the await below errors, and
    // that maps to `None`.
    let _ = self.coord_tx.send(request);
    answer.await.unwrap_or(None)
}
/// Returns the time elapsed since the last data was received from this
/// channel's server.
///
/// Returns `Duration::ZERO` when the channel is not operational, when no
/// receive time is known, or when the coordinator is unavailable.
pub async fn receive_watchdog_delay(&self) -> Duration {
    let (reply, answer) = oneshot::channel();
    let _ = self.coord_tx.send(CoordRequest::GetWatchdogDelay {
        cid: self.cid,
        reply,
    });
    if let Ok(Some(elapsed)) = answer.await {
        elapsed
    } else {
        Duration::ZERO
    }
}
}
/// Receiving end of a channel monitor, returned by `subscribe` /
/// `subscribe_with_deadband`.
///
/// Each received update is reported back to the coordinator
/// (`CoordRequest::MonitorConsumed`) for flow control. Dropping the handle
/// sends `CoordRequest::Unsubscribe`.
pub struct MonitorHandle {
/// Subscription id from `alloc_subid`; identifies this monitor to the coordinator.
subid: u32,
/// Queue of decoded updates (or errors); the sending half was handed to the
/// coordinator in `CoordRequest::Subscribe`.
callback_rx: mpsc::Receiver<CaResult<Snapshot>>,
/// Used to acknowledge consumed updates and to unsubscribe on drop.
coord_tx: mpsc::UnboundedSender<CoordRequest>,
}
impl MonitorHandle {
    /// Waits for the next monitor update.
    ///
    /// Each delivered item is acknowledged to the coordinator so it can
    /// release flow control for the server. Returns `None` once the
    /// coordinator has dropped this monitor's sender.
    pub async fn recv(&mut self) -> Option<CaResult<Snapshot>> {
        let update = self.callback_rx.recv().await?;
        let _ = self
            .coord_tx
            .send(CoordRequest::MonitorConsumed { subid: self.subid });
        Some(update)
    }
}
impl Drop for MonitorHandle {
    /// Asks the coordinator to cancel this subscription.
    ///
    /// Best-effort: if the coordinator is already gone there is nothing left
    /// to cancel, so a send failure is ignored.
    fn drop(&mut self) {
        let request = CoordRequest::Unsubscribe { subid: self.subid };
        let _ = self.coord_tx.send(request);
    }
}
/// Outstanding (queued but not yet consumed) monitor updates for one server
/// at which `flow_control_note_queued` sends `TransportCommand::EventsOff`.
const FLOW_CONTROL_OFF_THRESHOLD: usize = 10;
/// Outstanding count at or below which `flow_control_note_consumed` sends
/// `TransportCommand::EventsOn` again. It sits below the OFF threshold, which
/// gives hysteresis.
const FLOW_CONTROL_ON_THRESHOLD: usize = 5;
/// Per-server monitor flow-control bookkeeping kept by the coordinator.
#[derive(Default)]
struct FlowControlState {
/// Monitor updates queued to consumers and not yet reported consumed.
outstanding: usize,
/// `true` while an `EventsOff` has been sent and not yet undone by `EventsOn`.
active: bool,
}
/// Records one more monitor update queued for `server_addr`.
///
/// Once the outstanding count reaches `FLOW_CONTROL_OFF_THRESHOLD`, asks the
/// transport to pause event delivery from that server (`EventsOff`), at most
/// once until it is resumed.
fn flow_control_note_queued(
    flow_control: &mut HashMap<SocketAddr, FlowControlState>,
    server_addr: SocketAddr,
    transport_tx: &mpsc::UnboundedSender<TransportCommand>,
) {
    let entry = flow_control.entry(server_addr).or_default();
    entry.outstanding = entry.outstanding.saturating_add(1);
    let should_pause = !entry.active && entry.outstanding >= FLOW_CONTROL_OFF_THRESHOLD;
    if should_pause {
        let _ = transport_tx.send(TransportCommand::EventsOff { server_addr });
        entry.active = true;
    }
}
/// Records that `count` monitor updates for `server_addr` were consumed or
/// discarded.
///
/// If event delivery was paused and the outstanding count has fallen to
/// `FLOW_CONTROL_ON_THRESHOLD` or below, sends `EventsOn`. An entry that is
/// neither paused nor has anything outstanding is removed. Unknown servers
/// and a zero `count` are ignored.
fn flow_control_note_consumed(
    flow_control: &mut HashMap<SocketAddr, FlowControlState>,
    server_addr: SocketAddr,
    count: usize,
    transport_tx: &mpsc::UnboundedSender<TransportCommand>,
) {
    if count == 0 {
        return;
    }
    let Some(entry) = flow_control.get_mut(&server_addr) else {
        return;
    };
    entry.outstanding = entry.outstanding.saturating_sub(count);
    let should_resume = entry.active && entry.outstanding <= FLOW_CONTROL_ON_THRESHOLD;
    if should_resume {
        let _ = transport_tx.send(TransportCommand::EventsOn { server_addr });
        entry.active = false;
    }
    let idle = !entry.active && entry.outstanding == 0;
    if idle {
        flow_control.remove(&server_addr);
    }
}
#[allow(clippy::too_many_arguments)]
async fn run_coordinator(
mut coord_rx: mpsc::UnboundedReceiver<CoordRequest>,
mut search_rx: mpsc::UnboundedReceiver<SearchResponse>,
mut transport_rx: mpsc::UnboundedReceiver<TransportEvent>,
search_tx: mpsc::UnboundedSender<SearchRequest>,
transport_tx: mpsc::UnboundedSender<TransportCommand>,
in_flight: types::InFlightOps,
snapshots: ChannelSnapshots,
last_rx_at: ServerLastRxAt,
diag: Arc<CaDiagnostics>,
exception_slot: types::CaExceptionSlot,
search_attempts: types::SearchAttempts,
beacon_ctrl_tx: mpsc::UnboundedSender<beacon_monitor::BeaconControl>,
) {
let mut channels: HashMap<u32, ChannelInner> = HashMap::new();
let mut pending_wait_connected: HashMap<u32, Vec<oneshot::Sender<()>>> = HashMap::new();
let mut pending_found: HashMap<u32, SocketAddr> = HashMap::new();
let mut subscriptions = SubscriptionRegistry::new();
let mut server_channels: HashMap<SocketAddr, HashSet<u32>> = HashMap::new();
let mut flow_control: HashMap<SocketAddr, FlowControlState> = HashMap::new();
let mut server_minor_version: HashMap<SocketAddr, u16> = HashMap::new();
loop {
tokio::select! {
req = coord_rx.recv() => {
let Some(req) = req else { return };
match req {
CoordRequest::RegisterChannel { cid, pv_name, conn_tx } => {
let early_waiters = pending_wait_connected
.remove(&cid)
.unwrap_or_default();
channels.insert(cid, ChannelInner {
cid,
pv_name: pv_name.clone(),
state: ChannelState::Searching,
sid: 0,
native_type: None,
element_count: 0,
server_addr: None,
access_rights: AccessRights::from_u32(0),
connect_waiters: early_waiters,
conn_tx,
reconnect_count: 0,
last_connected_at: None,
});
if let Some(server_addr) = pending_found.remove(&cid) {
let ch = channels.get_mut(&cid).unwrap();
ch.state = ChannelState::Connecting;
ch.server_addr = Some(server_addr);
server_channels.entry(server_addr).or_default().insert(cid);
let _ = transport_tx.send(TransportCommand::CreateChannel {
cid,
pv_name,
server_addr,
});
}
}
CoordRequest::WaitConnected { cid, reply } => {
if let Some(ch) = channels.get_mut(&cid) {
if ch.state == ChannelState::Connected {
let _ = reply.send(());
} else {
ch.connect_waiters.push(reply);
}
} else {
pending_wait_connected
.entry(cid)
.or_default()
.push(reply);
}
}
CoordRequest::Subscribe { cid, subid, mask, deadband, callback_tx, reply } => {
if let Some(ch) = channels.get(&cid) {
let server_addr = ch.server_addr.unwrap_or_else(|| {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
});
let connected = ch.state == ChannelState::Connected;
let data_type = ch.native_type.map(|t| t as u16 + 14);
let count = ch.native_type.map(|_| ch.element_count);
subscriptions.add(subscription::SubscriptionRecord {
subid,
cid,
data_type,
count,
mask,
server_addr,
deadband,
callback_tx,
needs_restore: !connected,
last_value: None,
pending_deliveries: 0,
});
if connected {
let _ = transport_tx.send(TransportCommand::Subscribe {
sid: ch.sid,
data_type: data_type.expect("connected channel has native type"),
count: count.expect("connected channel has element count"),
subid,
mask,
server_addr,
});
}
let _ = reply.send(Ok(()));
} else {
let _ = reply.send(Err(CaError::Disconnected));
}
}
CoordRequest::Unsubscribe { subid } => {
if let Some(rec) = subscriptions.get(subid) {
let cid = rec.cid;
if let Some(ch) = channels.get(&cid) {
if ch.state == ChannelState::Connected {
if let Some(data_type) = rec.data_type {
let _ = transport_tx.send(TransportCommand::Unsubscribe {
sid: ch.sid,
subid,
data_type,
server_addr: ch.server_addr.unwrap(),
});
}
}
}
}
if let Some(rec) = subscriptions.remove(subid) {
flow_control_note_consumed(
&mut flow_control,
rec.server_addr,
rec.pending_deliveries,
&transport_tx,
);
}
}
CoordRequest::MonitorConsumed { subid } => {
if let Some(server_addr) = subscriptions.mark_consumed(subid) {
flow_control_note_consumed(
&mut flow_control,
server_addr,
1,
&transport_tx,
);
}
}
CoordRequest::DropChannel { cid } => {
let sub_ids = subscriptions.for_cid(cid);
for subid in sub_ids {
if let Some(rec) = subscriptions.get(subid) {
if let Some(ch) = channels.get(&cid) {
if ch.state == ChannelState::Connected {
if let Some(data_type) = rec.data_type {
let _ = transport_tx.send(TransportCommand::Unsubscribe {
sid: ch.sid,
subid,
data_type,
server_addr: ch.server_addr.unwrap(),
});
}
}
}
}
if let Some(rec) = subscriptions.remove(subid) {
flow_control_note_consumed(
&mut flow_control,
rec.server_addr,
rec.pending_deliveries,
&transport_tx,
);
}
}
if let Some(ch) = channels.get(&cid) {
if ch.state.is_operational() {
let _ = transport_tx.send(TransportCommand::ClearChannel {
cid,
sid: ch.sid,
server_addr: ch.server_addr.unwrap(),
});
}
match ch.state {
ChannelState::Searching
| ChannelState::Connecting
| ChannelState::Disconnected => {
let _ = search_tx.send(SearchRequest::Cancel { cid });
}
_ => {}
}
if let Some(addr) = ch.server_addr {
remove_server_channel(&mut server_channels, addr, cid);
}
}
channels.remove(&cid);
snapshots.remove(&cid);
let mut affected = HashSet::with_capacity(1);
affected.insert(cid);
drain_waiters_for_cids(&affected, &in_flight);
}
CoordRequest::GetWatchdogDelay { cid, reply } => {
let delay = channels.get(&cid).and_then(|ch| {
if !ch.state.is_operational() {
return None;
}
let addr = ch.server_addr?;
last_rx_at.get(&addr).map(|e| e.value().elapsed())
});
let _ = reply.send(delay);
}
CoordRequest::GetHostMinorProtocol { cid, reply } => {
let v = channels.get(&cid).and_then(|ch| {
if !ch.state.is_operational() {
return None;
}
let addr = ch.server_addr?;
server_minor_version.get(&addr).copied()
});
let _ = reply.send(v);
}
CoordRequest::GetIocConnectionCount { reply } => {
let mut servers = HashSet::<SocketAddr>::new();
for ch in channels.values() {
if let Some(addr) = ch.server_addr {
if ch.state.is_operational() {
servers.insert(addr);
}
}
}
let _ = reply.send(servers.len());
}
CoordRequest::Shutdown { reply } => {
for ch in channels.values() {
if ch.state.is_operational() {
if let Some(addr) = ch.server_addr {
let _ = transport_tx.send(TransportCommand::ClearChannel {
cid: ch.cid,
sid: ch.sid,
server_addr: addr,
});
}
}
}
let _ = reply.send(());
return; }
CoordRequest::ForceRescanServer { server_addr, kind } => {
let is_real_restart = matches!(
kind,
beacon_monitor::BeaconAnomalyKind::IdMismatch
| beacon_monitor::BeaconAnomalyKind::PeriodCollapse
);
if is_real_restart {
diag.beacon_anomalies.fetch_add(1, Ordering::Relaxed);
diag.record(DiagEvent::BeaconAnomaly { server: server_addr });
tracing::warn!(
server = %server_addr,
?kind,
"beacon anomaly detected — IOC may have restarted"
);
metrics::counter!(
"ca_client_beacon_anomalies_total",
"server" => server_addr.to_string()
)
.increment(1);
} else {
tracing::debug!(
server = %server_addr,
"first sighting of beacon source — waking pending searches"
);
metrics::counter!(
"ca_client_beacon_first_sighting_total",
"server" => server_addr.to_string()
)
.increment(1);
}
for ch in channels.values() {
if ch.state == ChannelState::Disconnected
|| ch.state == ChannelState::Searching
{
let _ = search_tx.send(SearchRequest::Schedule {
cid: ch.cid,
pv_name: ch.pv_name.to_string(),
reason: SearchReason::BeaconAnomaly,
});
}
}
}
CoordRequest::BeaconArrival { server_addr, anomaly } => {
let states = channels.values().map(|ch| (ch.state, ch.server_addr));
for target in beacon_arrival_targets(states, server_addr) {
let _ = transport_tx.send(
TransportCommand::BeaconArrivalNotify {
server_addr: target,
anomaly,
},
);
}
}
}
}
resp = search_rx.recv() => {
let Some(resp) = resp else { return };
match resp {
SearchResponse::Found { cid, server_addr } => {
if let Some(ch) = channels.get_mut(&cid) {
if ch.state == ChannelState::Searching || ch.state == ChannelState::Disconnected {
if let Some(old_addr) = ch.server_addr {
remove_server_channel(&mut server_channels, old_addr, cid);
}
ch.state = ChannelState::Connecting;
ch.server_addr = Some(server_addr);
server_channels.entry(server_addr).or_default().insert(cid);
let _ = transport_tx.send(TransportCommand::CreateChannel {
cid,
pv_name: ch.pv_name.to_string(),
server_addr,
});
}
} else {
pending_found.insert(cid, server_addr);
}
}
}
}
evt = transport_rx.recv() => {
let Some(evt) = evt else { return };
match evt {
TransportEvent::ChannelCreated { cid, sid, data_type, element_count, access, server_addr } => {
if let Some(ch) = channels.get_mut(&cid) {
let was_disconnected = matches!(ch.state, ChannelState::Disconnected);
let dbr_type = DbFieldType::from_u16(data_type).ok();
ch.state = ChannelState::Connected;
ch.sid = sid;
ch.native_type = dbr_type;
ch.element_count = element_count;
ch.server_addr = Some(server_addr);
ch.access_rights = access;
ch.last_connected_at = Some(std::time::Instant::now());
if let Some(dbr) = dbr_type {
snapshots.insert(
cid,
types::ChannelSnapshotPublic {
sid,
native_type: dbr,
element_count,
server_addr,
access_rights: access,
state: ChannelState::Connected,
},
);
} else {
snapshots.remove(&cid);
}
if was_disconnected {
tracing::info!(pv = %ch.pv_name, cid, sid, server = %server_addr, "channel reconnected");
} else {
tracing::info!(pv = %ch.pv_name, cid, sid, server = %server_addr, "channel connected");
}
metrics::counter!("ca_client_connections_total", "server" => server_addr.to_string()).increment(1);
metrics::gauge!("ca_client_channels_connected").increment(1.0);
search_attempts.remove(&cid);
for waiter in ch.connect_waiters.drain(..) {
let _ = waiter.send(());
}
let _ = ch.conn_tx.send(ConnectionEvent::Connected);
let _ = ch.conn_tx.send(ConnectionEvent::AccessRightsChanged {
read: access.read,
write: access.write,
});
let (restored, stale) = subscriptions.restore_for_channel(
cid,
sid,
data_type,
element_count,
server_addr,
&transport_tx,
);
diag.connections.fetch_add(1, Ordering::Relaxed);
diag.record(DiagEvent::Connected { pv: ch.pv_name.to_string(), server: server_addr });
if restored > 0 || stale > 0 {
diag.reconnections.fetch_add(1, Ordering::Relaxed);
diag.subscriptions_restored.fetch_add(restored as u64, Ordering::Relaxed);
diag.subscriptions_stale.fetch_add(stale as u64, Ordering::Relaxed);
diag.record(DiagEvent::Reconnected { pv: ch.pv_name.to_string(), restored, stale });
eprintln!("CA: {}: restored {restored} subscriptions ({stale} stale removed)", ch.pv_name);
}
let _ = search_tx.send(SearchRequest::ConnectResult {
cid,
success: true,
server_addr,
});
}
}
TransportEvent::MonitorData { subid, data_type, count, data } => {
use subscription::MonitorDeliveryOutcome;
match subscriptions.on_monitor_data(subid, data_type, count, &data) {
MonitorDeliveryOutcome::Queued(server_addr) => {
flow_control_note_queued(
&mut flow_control,
server_addr,
&transport_tx,
);
}
MonitorDeliveryOutcome::Dropped(_server_addr) => {
diag.dropped_monitors.fetch_add(1, Ordering::Relaxed);
tracing::warn!(subid, "monitor dropped (consumer queue full)");
metrics::counter!("ca_client_dropped_monitors_total").increment(1);
}
MonitorDeliveryOutcome::Filtered
| MonitorDeliveryOutcome::NotFound => {}
}
}
TransportEvent::AccessRightsChanged { cid, access } => {
if let Some(ch) = channels.get_mut(&cid) {
ch.access_rights = access;
let _ = ch.conn_tx.send(ConnectionEvent::AccessRightsChanged {
read: access.read,
write: access.write,
});
if let Some(mut snap) = snapshots.get_mut(&cid) {
snap.access_rights = access;
}
}
}
TransportEvent::ChannelCreateFailed { cid } => {
if let Some(ch) = channels.get_mut(&cid) {
let server_addr = ch.server_addr;
ch.state = ChannelState::Disconnected;
snapshots.remove(&cid);
let _ = ch.conn_tx.send(ConnectionEvent::Disconnected);
let _ = search_tx.send(SearchRequest::Schedule {
cid,
pv_name: ch.pv_name.to_string(),
reason: SearchReason::Reconnect,
});
if let Some(addr) = server_addr {
let _ = search_tx.send(SearchRequest::ConnectResult {
cid,
success: false,
server_addr: addr,
});
}
}
}
TransportEvent::ServerError {
eca_status,
original_request,
message,
server_addr,
} => {
let annotated = match original_request {
Some(cmd) => {
if message.is_empty() {
format!("(while processing cmd={cmd})")
} else {
format!("{message} (while processing cmd={cmd})")
}
}
None => message,
};
types::dispatch_exception(
&exception_slot,
types::CaException {
kind: types::CaExceptionKind::ServerError,
message: annotated,
server_addr: Some(server_addr),
pv_name: None,
status: Some(eca_status),
},
);
}
TransportEvent::TcpClosed { server_addr } => {
let n_affected = server_channels
.get(&server_addr)
.map(|s| s.len())
.unwrap_or(0);
tracing::warn!(server = %server_addr, channels = n_affected, "TCP circuit closed");
metrics::counter!("ca_client_tcp_closed_total", "server" => server_addr.to_string()).increment(1);
flow_control.remove(&server_addr);
last_rx_at.remove(&server_addr);
server_minor_version.remove(&server_addr);
handle_disconnect(&mut channels, &mut subscriptions, &mut server_channels, &search_tx, server_addr, &diag, &in_flight, &snapshots);
}
TransportEvent::ServerDisconnect { cid, server_addr } => {
if let Some(ch) = channels.get_mut(&cid) {
if ch.server_addr == Some(server_addr) {
ch.state = ChannelState::Disconnected;
snapshots.remove(&cid);
let _ = ch.conn_tx.send(ConnectionEvent::Disconnected);
let pv_name = ch.pv_name.to_string();
let cids = vec![cid];
let cleared = subscriptions.mark_disconnected(&cids);
for (addr, count) in cleared {
flow_control_note_consumed(
&mut flow_control,
addr,
count,
&transport_tx,
);
}
let mut affected = HashSet::with_capacity(1);
affected.insert(cid);
drain_waiters_for_cids(&affected, &in_flight);
let _ = search_tx.send(SearchRequest::Schedule {
cid,
pv_name: pv_name.clone(),
reason: SearchReason::Reconnect,
});
types::dispatch_exception(
&exception_slot,
types::CaException {
kind: types::CaExceptionKind::ServerDisconnect,
message: "server-initiated channel close".to_string(),
server_addr: Some(server_addr),
pv_name: Some(pv_name),
status: None,
},
);
}
}
}
TransportEvent::CircuitUnresponsive { server_addr } => {
diag.unresponsive_events.fetch_add(1, Ordering::Relaxed);
diag.record(DiagEvent::Unresponsive { server: server_addr });
tracing::warn!(server = %server_addr, "circuit unresponsive (echo timeout)");
metrics::counter!("ca_client_unresponsive_total", "server" => server_addr.to_string()).increment(1);
for ch in channels.values_mut() {
if ch.server_addr == Some(server_addr)
&& ch.state == ChannelState::Connected
{
ch.state = ChannelState::Unresponsive;
if let Some(mut snap) = snapshots.get_mut(&ch.cid) {
snap.state = ChannelState::Unresponsive;
}
let _ = ch.conn_tx.send(ConnectionEvent::Unresponsive);
}
}
}
TransportEvent::CircuitResponsive { server_addr } => {
diag.record(DiagEvent::Responsive { server: server_addr });
tracing::info!(server = %server_addr, "circuit responsive again");
for ch in channels.values_mut() {
if ch.server_addr == Some(server_addr)
&& ch.state == ChannelState::Unresponsive
{
ch.state = ChannelState::Connected;
if let Some(mut snap) = snapshots.get_mut(&ch.cid) {
snap.state = ChannelState::Connected;
}
let _ = ch.conn_tx.send(ConnectionEvent::Connected);
}
}
}
TransportEvent::ServerVersion { server_addr, minor_version } => {
server_minor_version.insert(server_addr, minor_version);
}
TransportEvent::ServerConnected { server_addr } => {
let _ = beacon_ctrl_tx.send(
beacon_monitor::BeaconControl::ResetServer { server_addr },
);
}
}
}
}
}
}
/// Moves every channel bound to `server_addr` to `Disconnected` after its TCP
/// circuit closed, and schedules each one for a reconnect search.
///
/// Affected channels are those on `server_addr` that were operational or
/// still `Connecting`. For each affected channel:
/// - its state becomes `Disconnected` and its snapshot is removed;
/// - `ConnectionEvent::Disconnected` is broadcast;
/// - `reconnect_count` is reset if the channel had been connected for more
///   than 30 s, and incremented otherwise;
/// - a `SearchReason::Reconnect` search is scheduled.
///
/// Afterwards the server's entry in `server_channels` is dropped, the
/// channels' subscriptions are marked disconnected, and in-flight reads and
/// writes for them are failed with `CaError::Disconnected`.
///
/// Only channels that were operational are subtracted from the
/// `ca_client_channels_connected` gauge, because that gauge is incremented
/// when a channel is created on the server. `Connecting` channels were never
/// counted.
#[allow(clippy::too_many_arguments)]
fn handle_disconnect(
    channels: &mut HashMap<u32, ChannelInner>,
    subscriptions: &mut SubscriptionRegistry,
    server_channels: &mut HashMap<SocketAddr, HashSet<u32>>,
    search_tx: &mpsc::UnboundedSender<SearchRequest>,
    server_addr: SocketAddr,
    diag: &CaDiagnostics,
    in_flight: &types::InFlightOps,
    snapshots: &ChannelSnapshots,
) {
    let mut affected_cids = Vec::new();
    // Affected channels that had been counted in the connected gauge.
    let mut previously_connected: usize = 0;
    let now = std::time::Instant::now();
    for ch in channels.values_mut() {
        if ch.server_addr == Some(server_addr)
            && (ch.state.is_operational() || ch.state == ChannelState::Connecting)
        {
            if ch.state.is_operational() {
                previously_connected += 1;
            }
            ch.state = ChannelState::Disconnected;
            snapshots.remove(&ch.cid);
            affected_cids.push(ch.cid);
            let _ = ch.conn_tx.send(ConnectionEvent::Disconnected);
            // A channel that stayed up for a while starts its reconnect
            // counter afresh; quick flaps keep accumulating.
            let sustained = ch
                .last_connected_at
                .map(|t| now.duration_since(t).as_secs() > 30)
                .unwrap_or(false);
            if sustained {
                ch.reconnect_count = 0;
            } else {
                ch.reconnect_count = ch.reconnect_count.saturating_add(1);
            }
            let _ = search_tx.send(SearchRequest::Schedule {
                cid: ch.cid,
                pv_name: ch.pv_name.to_string(),
                reason: SearchReason::Reconnect,
            });
        }
    }
    if !affected_cids.is_empty() {
        diag.disconnections.fetch_add(1, Ordering::Relaxed);
        diag.record(DiagEvent::Disconnected {
            server: server_addr,
            channels: affected_cids.len(),
        });
        tracing::warn!(
            server = %server_addr,
            affected = affected_cids.len(),
            "disconnect: scheduling reconnect for affected channels"
        );
        metrics::counter!("ca_client_disconnections_total", "server" => server_addr.to_string())
            .increment(1);
        if previously_connected > 0 {
            metrics::gauge!("ca_client_channels_connected")
                .decrement(previously_connected as f64);
        }
    }
    server_channels.remove(&server_addr);
    // The per-server flow-control counts returned here are not needed. The
    // only caller in view (TcpClosed) removes this server's `flow_control`
    // entry before calling.
    let _ = subscriptions.mark_disconnected(&affected_cids);
    let affected: HashSet<u32> = affected_cids.into_iter().collect();
    drain_waiters_for_cids(&affected, in_flight);
}
/// Fails every in-flight read and write that belongs to one of `cids`,
/// completing each with `Err(CaError::Disconnected)`.
///
/// Operations for other channels are left untouched. An empty `cids`
/// changes nothing.
pub(crate) fn drain_waiters_for_cids(cids: &HashSet<u32>, in_flight: &types::InFlightOps) {
    // Gather the matching ioids first and remove them afterwards, instead of
    // mutating the map while iterating over it.
    let read_ioids: Vec<u32> = in_flight
        .reads
        .iter()
        .filter_map(|slot| cids.contains(&slot.value().cid()).then(|| *slot.key()))
        .collect();
    for ioid in read_ioids {
        let Some((_, waiter)) = in_flight.reads.remove(&ioid) else {
            continue;
        };
        waiter.send(Err(CaError::Disconnected));
    }

    let write_ioids: Vec<u32> = in_flight
        .writes
        .iter()
        .filter_map(|slot| {
            let (owner_cid, _) = slot.value();
            cids.contains(owner_cid).then(|| *slot.key())
        })
        .collect();
    for ioid in write_ioids {
        if let Some((_, (_, sender))) = in_flight.writes.remove(&ioid) {
            let _ = sender.send(Err(CaError::Disconnected));
        }
    }
}
/// Chooses which server circuits should receive a `BeaconArrivalNotify` for a
/// beacon seen from `beacon_addr`.
///
/// Only operational channels with a known server address are considered.
/// If one of them sits on exactly `beacon_addr`, that address alone is
/// returned; this does not apply when the beacon address is unspecified
/// (`0.0.0.0`). Otherwise every distinct server address that shares the
/// beacon's port is returned, in unspecified order. This covers beacons from
/// an unspecified source or from another interface of a multi-homed IOC.
fn beacon_arrival_targets<I>(channel_states: I, beacon_addr: SocketAddr) -> Vec<SocketAddr>
where
    I: IntoIterator<Item = (ChannelState, Option<SocketAddr>)>,
{
    let exact_possible = !beacon_addr.ip().is_unspecified();
    let live_addrs = channel_states
        .into_iter()
        .filter(|(state, _)| state.is_operational())
        .filter_map(|(_, addr)| addr);

    let mut same_port: HashSet<SocketAddr> = HashSet::new();
    for addr in live_addrs {
        if exact_possible && addr == beacon_addr {
            return vec![beacon_addr];
        }
        if addr.port() == beacon_addr.port() {
            same_port.insert(addr);
        }
    }
    same_port.into_iter().collect()
}
/// Removes `cid` from the channel set recorded for `server_addr`.
///
/// The server's entry is dropped once its set becomes empty. Unknown servers
/// and unknown cids are ignored.
fn remove_server_channel(
    server_channels: &mut HashMap<SocketAddr, HashSet<u32>>,
    server_addr: SocketAddr,
    cid: u32,
) {
    let Some(cids) = server_channels.get_mut(&server_addr) else {
        return;
    };
    cids.remove(&cid);
    if cids.is_empty() {
        server_channels.remove(&server_addr);
    }
}
/// Resolves `host` (a dotted-quad IPv4 literal or a DNS name) plus `port` to
/// a socket address.
///
/// IPv4 literals are returned directly without a lookup. For names, the
/// first IPv4 result is preferred; otherwise the first result of any family
/// is used.
///
/// # Errors
/// Returns `CaError::Protocol` when resolution fails or yields no addresses.
fn resolve_host(host: &str, port: u16) -> CaResult<SocketAddr> {
    use std::net::ToSocketAddrs;
    if let Ok(ip) = host.parse::<Ipv4Addr>() {
        return Ok(SocketAddr::V4(SocketAddrV4::new(ip, port)));
    }
    let candidates: Vec<SocketAddr> = format!("{host}:{port}")
        .to_socket_addrs()
        .map_err(|e| CaError::Protocol(format!("cannot resolve '{host}': {e}")))?
        .collect();
    let preferred = candidates.iter().copied().find(SocketAddr::is_ipv4);
    preferred
        .or_else(|| candidates.first().copied())
        .ok_or_else(|| CaError::Protocol(format!("no addresses for '{host}'")))
}
/// One destination address for the client. It optionally remembers the
/// hostname it was resolved from, so `refresh_dns` can re-resolve it later.
#[derive(Debug, Clone)]
pub(crate) struct AddrEntry {
/// Currently resolved socket address.
pub sock: SocketAddr,
/// Hostname the entry came from, or `None` for IPv4 literals and
/// auto-discovered broadcast addresses (see `parse_addr_list_with_hostnames`).
pub hostname: Option<String>,
/// Port used when re-resolving `hostname`.
pub port: u16,
}
impl AddrEntry {
pub fn new(sock: SocketAddr, hostname: Option<String>, port: u16) -> Self {
Self {
sock,
hostname,
port,
}
}
pub fn refresh_dns(&mut self) -> CaResult<SocketAddr> {
let Some(host) = self.hostname.as_deref() else {
return Ok(self.sock);
};
let new_sock = resolve_host(host, self.port)?;
self.sock = new_sock;
Ok(new_sock)
}
}
/// Builds the client's search-destination list from the environment.
///
/// * `EPICS_CA_ADDR_LIST`: whitespace-separated `host[:port]` entries.
///   Entries without a port use `EPICS_CA_SERVER_PORT` (default
///   `CA_SERVER_PORT`). Entries with an invalid port or an unresolvable host
///   are skipped and logged at debug level. Hostname entries keep their name
///   in `AddrEntry::hostname` for later re-resolution; IPv4 literals do not.
/// * `EPICS_CA_AUTO_ADDR_LIST`: when `YES` (the default, case-insensitive),
///   appends each discovered interface broadcast address and
///   `255.255.255.255` on the server port, skipping addresses already listed.
///
/// Currently always returns `Ok`.
pub(crate) fn parse_addr_list_with_hostnames() -> CaResult<Vec<AddrEntry>> {
    let mut addrs: Vec<AddrEntry> = Vec::new();
    let default_port = epics_base_rs::runtime::env::get("EPICS_CA_SERVER_PORT")
        .and_then(|s| s.parse::<u16>().ok())
        .unwrap_or(CA_SERVER_PORT);
    if let Some(list) = epics_base_rs::runtime::env::get("EPICS_CA_ADDR_LIST") {
        for entry in list.split_whitespace() {
            let (host_raw, port) = match entry.rsplit_once(':') {
                Some((host, port_str)) => match port_str.parse::<u16>() {
                    Ok(port) => (host, port),
                    Err(_) => {
                        // Logged like unresolvable entries below, instead
                        // of vanishing silently.
                        tracing::debug!(token = %entry,
                            "EPICS_CA_ADDR_LIST: dropped entry with invalid port");
                        continue;
                    }
                },
                None => (entry, default_port),
            };
            let hostname = host_raw
                .parse::<Ipv4Addr>()
                .is_err()
                .then(|| host_raw.to_string());
            match resolve_host(host_raw, port) {
                Ok(sock) => addrs.push(AddrEntry::new(sock, hostname, port)),
                Err(e) => tracing::debug!(token = %entry, error = %e,
                    "EPICS_CA_ADDR_LIST: dropped unresolvable entry"),
            }
        }
    }
    let auto_addr = epics_base_rs::runtime::env::get_or("EPICS_CA_AUTO_ADDR_LIST", "YES");
    if auto_addr.eq_ignore_ascii_case("YES") {
        let server_port = default_port;
        for bcast in crate::server::addr_list::discover_broadcast_addrs() {
            let sock = SocketAddr::V4(SocketAddrV4::new(bcast, server_port));
            if !addrs.iter().any(|e| e.sock == sock) {
                addrs.push(AddrEntry::new(sock, None, server_port));
            }
        }
        let fallback = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::BROADCAST, server_port));
        if !addrs.iter().any(|e| e.sock == fallback) {
            addrs.push(AddrEntry::new(fallback, None, server_port));
        }
    }
    Ok(addrs)
}
// Unit tests for `AddrEntry` construction and `refresh_dns`.
#[cfg(test)]
mod addr_entry_tests {
use super::*;
// `AddrEntry::new` given `hostname: None` keeps it `None`.
// NOTE(review): this only checks that the constructor stores what it is
// given; it does not exercise the IP-literal detection in
// `parse_addr_list_with_hostnames`.
#[test]
fn ip_literal_has_no_hostname() {
let entry = AddrEntry::new(
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 5064)),
None,
5064,
);
assert!(entry.hostname.is_none());
}
// Without a hostname, `refresh_dns` performs no lookup and returns the
// stored address unchanged.
#[test]
fn refresh_noop_for_literal_ip() {
let original_sock = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 5064));
let mut entry = AddrEntry::new(original_sock, None, 5064);
let refreshed = entry.refresh_dns().expect("noop refresh succeeds");
assert_eq!(refreshed, original_sock);
}
}
/// Returns `name` with `${VAR}` / `$(VAR)` references expanded when
/// `EPICS_CA_USE_SHELL_VARS` is `YES` (case-insensitive). Otherwise, including
/// when the variable is unset, returns `name` unchanged.
fn expand_pv_name(name: &str) -> String {
    let use_shell_vars = epics_base_rs::runtime::env::get_or("EPICS_CA_USE_SHELL_VARS", "NO");
    let expand = use_shell_vars.eq_ignore_ascii_case("YES");
    if !expand {
        return name.to_string();
    }
    expand_shell_vars(name)
}
/// Expands `${NAME}` and `$(NAME)` references in `s` using
/// `epics_base_rs::runtime::env::get`. Unset variables expand to the empty
/// string.
///
/// A `$` not followed by `{` or `(` with a matching closing delimiter is kept
/// literally, as is the unmatched opener. Text between references is copied as
/// string slices, so multi-byte UTF-8 characters are preserved. Copying byte by
/// byte via `u8 as char` would turn e.g. `é` into `Ã©`.
fn expand_shell_vars(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut rest = s;
    while let Some(dollar) = rest.find('$') {
        out.push_str(&rest[..dollar]);
        let after = &rest[dollar + 1..];
        let close = match after.as_bytes().first() {
            Some(b'{') => Some('}'),
            Some(b'(') => Some(')'),
            _ => None,
        };
        if let Some(end) = close {
            // All delimiters are ASCII, so these byte offsets are char boundaries.
            if let Some(len) = after[1..].find(end) {
                let name = &after[1..1 + len];
                let value = epics_base_rs::runtime::env::get(name).unwrap_or_default();
                out.push_str(&value);
                rest = &after[len + 2..];
                continue;
            }
        }
        out.push('$');
        rest = after;
    }
    out.push_str(rest);
    out
}
/// Parses `EPICS_CA_TLS_SNI_MAP` into `(address, SNI hostname)` pairs.
///
/// Entries are whitespace-separated `ADDR=HOST`. `ADDR` is either a socket
/// address (`ip:port`, `[v6]:port`) or a bare IP literal, IPv4 or IPv6; a bare
/// IP is stored with port `0`. Entries that lack `=`, have an empty host or
/// have an unparseable address are skipped with a warning. Returns an empty
/// list when the variable is unset.
#[cfg(feature = "experimental-rust-tls")]
fn parse_tls_sni_map() -> Vec<(SocketAddr, String)> {
    let Some(list) = epics_base_rs::runtime::env::get("EPICS_CA_TLS_SNI_MAP") else {
        return Vec::new();
    };
    let mut out = Vec::new();
    for entry in list.split_whitespace() {
        let Some((addr_part, host)) = entry.split_once('=') else {
            tracing::warn!(entry = %entry,
                "EPICS_CA_TLS_SNI_MAP entry missing '=', skipping");
            continue;
        };
        if host.is_empty() {
            tracing::warn!(entry = %entry,
                "EPICS_CA_TLS_SNI_MAP entry has empty hostname, skipping");
            continue;
        }
        // Try a full socket address first, then a bare IP. A bare IPv6
        // literal contains ':' and must not be rejected as a bad `IP:port`.
        let addr = if let Ok(sock) = addr_part.parse::<SocketAddr>() {
            sock
        } else if let Ok(ip) = addr_part.parse::<std::net::IpAddr>() {
            SocketAddr::new(ip, 0)
        } else {
            if addr_part.contains(':') {
                tracing::warn!(entry = %entry,
                    "EPICS_CA_TLS_SNI_MAP entry has unparseable IP:port, skipping");
            } else {
                tracing::warn!(entry = %entry,
                    "EPICS_CA_TLS_SNI_MAP entry has unparseable IP, skipping");
            }
            continue;
        };
        out.push((addr, host.to_string()));
    }
    out
}
pub(crate) fn parse_nameserver_list() -> Vec<(SocketAddr, Option<String>)> {
let Some(list) = epics_base_rs::runtime::env::get("EPICS_CA_NAME_SERVERS") else {
return Vec::new();
};
let mut out = Vec::new();
for entry in list.split_whitespace() {
if entry.contains(':') {
if let Ok(addr) = entry.parse::<SocketAddr>() {
out.push((addr, None));
continue;
}
let Some((host, port_str)) = entry.rsplit_once(':') else {
continue;
};
let Ok(port) = port_str.parse::<u16>() else {
continue;
};
if let Ok(addr) = resolve_host(host, port) {
let hostname = if host.parse::<std::net::IpAddr>().is_ok() {
None
} else {
Some(host.to_string())
};
out.push((addr, hostname));
}
} else {
if let Ok(addr) = resolve_host(entry, CA_SERVER_PORT) {
let hostname = if entry.parse::<std::net::IpAddr>().is_ok() {
None
} else {
Some(entry.to_string())
};
out.push((addr, hostname));
}
}
}
out
}
// Tests for `beacon_arrival_targets`: choosing which circuits get a
// beacon-arrival notify.
#[cfg(test)]
mod beacon_arrival_routing_tests {
use super::*;
// Parses a socket-address literal; panics on malformed test input.
fn addr(s: &str) -> SocketAddr {
s.parse().unwrap()
}
// When an operational channel sits on exactly the beacon's address, only that
// address is notified, even though another server shares the port.
#[test]
fn exact_match_dominates() {
let states = vec![
(ChannelState::Connected, Some(addr("10.0.0.1:5064"))),
(ChannelState::Connected, Some(addr("10.0.0.2:5064"))),
];
let targets = beacon_arrival_targets(states, addr("10.0.0.1:5064"));
assert_eq!(targets, vec![addr("10.0.0.1:5064")]);
}
// A beacon from the unspecified address (0.0.0.0) never matches exactly.
// Every operational server on the same port is notified, and only those.
// Targets are sorted because the result order is unspecified.
#[test]
fn unspecified_addr_falls_back_to_port_match() {
let states = vec![
(ChannelState::Connected, Some(addr("10.0.0.1:5064"))),
(ChannelState::Connected, Some(addr("10.0.0.2:5064"))),
(ChannelState::Connected, Some(addr("10.0.0.3:6000"))),
];
let mut targets = beacon_arrival_targets(states, addr("0.0.0.0:5064"));
targets.sort();
assert_eq!(
targets,
vec![addr("10.0.0.1:5064"), addr("10.0.0.2:5064")],
":6000 must NOT be a target for a :5064 beacon"
);
}
// A beacon from a different IP on the same port (e.g. another interface of
// the same IOC) still reaches the circuit through the port-only fallback.
#[test]
fn multi_homed_falls_back_to_port_match() {
let states = vec![
(ChannelState::Connected, Some(addr("10.0.0.2:5064"))),
];
let targets = beacon_arrival_targets(states, addr("10.0.0.1:5064"));
assert_eq!(
targets,
vec![addr("10.0.0.2:5064")],
"multi-homed IOC must be reachable via port-only fallback"
);
}
// Searching/Disconnected channels are ignored even if their address matches.
#[test]
fn non_operational_channels_do_not_match() {
let states = vec![
(ChannelState::Searching, Some(addr("10.0.0.1:5064"))),
(ChannelState::Disconnected, Some(addr("10.0.0.2:5064"))),
];
let targets = beacon_arrival_targets(states, addr("10.0.0.1:5064"));
assert!(
targets.is_empty(),
"non-operational channels must not generate watchdog notifies"
);
}
// Different IP and different port: nothing is notified.
#[test]
fn no_match_returns_empty() {
let states = vec![(ChannelState::Connected, Some(addr("10.0.0.1:5064")))];
let targets = beacon_arrival_targets(states, addr("10.0.0.99:5065"));
assert!(targets.is_empty());
}
// Several channels on the same exact server yield a single notify target.
#[test]
fn exact_match_emits_single_notify() {
let states = vec![
(ChannelState::Connected, Some(addr("10.0.0.1:5064"))),
(ChannelState::Connected, Some(addr("10.0.0.1:5064"))),
];
let targets = beacon_arrival_targets(states, addr("10.0.0.1:5064"));
assert_eq!(targets, vec![addr("10.0.0.1:5064")]);
}
}
// TLS SNI configuration tests. Everything here is gated on the
// `experimental-rust-tls` feature; without it the module is empty.
#[cfg(test)]
mod tls_sni_config_tests {
#[cfg(feature = "experimental-rust-tls")]
use super::*;
// `CaClientConfig` (defined outside this chunk) should default
// `tls_server_name` to `None` and keep an explicitly assigned name unchanged.
#[cfg(feature = "experimental-rust-tls")]
#[test]
fn tls_server_name_round_trip() {
let mut cfg = CaClientConfig::default();
assert!(cfg.tls_server_name.is_none(), "default must be None");
cfg.tls_server_name = Some("ioc.example.com".into());
assert_eq!(cfg.tls_server_name.as_deref(), Some("ioc.example.com"));
}
}
#[cfg(test)]
mod waiter_drain_tests {
    //! Tests for `drain_waiters_for_cids`. In-flight read and write waiters
    //! whose CID is in the affected set are removed from the table and
    //! completed with `CaError::Disconnected`. Waiters for other CIDs stay
    //! in place.
    use super::*;
    use std::collections::HashSet;
    use tokio::sync::oneshot;

    /// Only the waiters registered for CID 42 are drained. The read and
    /// write waiters for CID 99 stay in the table.
    #[tokio::test(flavor = "current_thread")]
    async fn drain_wakes_matching_cid_only() {
        let in_flight = types::InFlightOps::new();
        let (rtx_42, rrx_42) = oneshot::channel();
        let (rtx_99, rrx_99) = oneshot::channel();
        let (wtx_42, wrx_42) = oneshot::channel();
        let (wtx_99, wrx_99) = oneshot::channel();
        in_flight.reads.insert(
            1001,
            types::ReadWaiter::OneShot {
                cid: 42,
                mode: types::ReadReplyMode::Raw,
                reply_tx: rtx_42,
            },
        );
        in_flight.reads.insert(
            2001,
            types::ReadWaiter::OneShot {
                cid: 99,
                mode: types::ReadReplyMode::Raw,
                reply_tx: rtx_99,
            },
        );
        // Write waiters are stored as (cid, reply sender).
        in_flight.writes.insert(1002, (42, wtx_42));
        in_flight.writes.insert(2002, (99, wtx_99));

        let mut affected = HashSet::new();
        affected.insert(42u32);
        drain_waiters_for_cids(&affected, &in_flight);

        // CID 42: removed from the table and woken with Disconnected.
        assert!(!in_flight.reads.contains_key(&1001));
        assert!(!in_flight.writes.contains_key(&1002));
        assert!(matches!(rrx_42.await, Ok(Err(CaError::Disconnected))));
        assert!(matches!(wrx_42.await, Ok(Err(CaError::Disconnected))));

        // CID 99: untouched by the drain.
        assert!(in_flight.reads.contains_key(&2001));
        assert!(in_flight.writes.contains_key(&2002));

        // After `in_flight` is dropped, the CID 99 receivers must see a
        // closed channel (Err) rather than any reply.
        drop(in_flight);
        assert!(
            rrx_99.await.is_err(),
            "undrained read waiter must not be answered"
        );
        assert!(
            wrx_99.await.is_err(),
            "undrained write waiter must not be answered"
        );
    }

    /// An empty affected set leaves every waiter in place and never answers it.
    #[tokio::test(flavor = "current_thread")]
    async fn drain_with_empty_cid_set_is_noop() {
        let in_flight = types::InFlightOps::new();
        let (rtx, rrx) = oneshot::channel();
        let (wtx, wrx) = oneshot::channel();
        in_flight.reads.insert(
            10,
            types::ReadWaiter::OneShot {
                cid: 1,
                mode: types::ReadReplyMode::Raw,
                reply_tx: rtx,
            },
        );
        in_flight.writes.insert(20, (2, wtx));

        let affected: HashSet<u32> = HashSet::new();
        drain_waiters_for_cids(&affected, &in_flight);

        assert!(in_flight.reads.contains_key(&10));
        assert!(in_flight.writes.contains_key(&20));

        // Nothing was sent, so dropping the table just closes the channels.
        drop(in_flight);
        assert!(rrx.await.is_err());
        assert!(wrx.await.is_err());
    }

    /// A reply delivered before the disconnect drain must reach the caller
    /// intact. The later drain for the same CID must not replace it.
    #[tokio::test(flavor = "current_thread")]
    async fn response_arrives_before_disconnect_drain() {
        let in_flight = types::InFlightOps::new();
        let (rtx, rrx) = oneshot::channel();
        in_flight.reads.insert(
            100,
            types::ReadWaiter::OneShot {
                cid: 7,
                mode: types::ReadReplyMode::Raw,
                reply_tx: rtx,
            },
        );

        // Simulate the response path: take the waiter out and answer it.
        // `expect` turns a missing waiter into an immediate, descriptive
        // failure instead of silently skipping the reply.
        let (_, waiter) = in_flight
            .reads
            .remove(&100)
            .expect("read waiter 100 was just inserted");
        waiter.send(Ok(types::ReadReply::Raw {
            data_type: 6,
            count: 1,
            data: vec![1, 0, 0, 0],
        }));

        let mut affected = HashSet::new();
        affected.insert(7u32);
        drain_waiters_for_cids(&affected, &in_flight);

        let result = rrx.await.expect("oneshot still alive");
        assert!(matches!(
            result,
            Ok(types::ReadReply::Raw {
                data_type: 6,
                count: 1,
                ..
            })
        ));
    }
}