use socket2::{Domain, Protocol, Socket, Type};
use std::net::{Ipv4Addr, SocketAddr};
use std::sync::Arc;
use tokio::net::UdpSocket;
use crate::protocol::*;
use epics_base_rs::error::CaResult;
use epics_base_rs::net::{enable_so_rxq_ovfl_for_socket, recv_from_with_drop_count_socket};
use epics_base_rs::server::database::PvDatabase;
/// Decide which responder sockets to create.
///
/// Every returned entry pairs the local address one responder socket binds
/// to with the multicast groups that socket must join. An empty
/// `intf_addrs` collapses to a single wildcard (`0.0.0.0`) responder;
/// otherwise there is exactly one responder per listed interface, in the
/// given order. Each responder receives its own copy of all of
/// `mcast_addrs`.
fn plan_responder_specs(
    intf_addrs: Vec<Ipv4Addr>,
    mcast_addrs: &[Ipv4Addr],
) -> Vec<(Ipv4Addr, Vec<Ipv4Addr>)> {
    let bind_targets = if !intf_addrs.is_empty() {
        intf_addrs
    } else {
        vec![Ipv4Addr::UNSPECIFIED]
    };
    let mut plan = Vec::with_capacity(bind_targets.len());
    for bind_ip in bind_targets {
        plan.push((bind_ip, mcast_addrs.to_vec()));
    }
    plan
}
/// Run the CA UDP search responder(s) until the first one stops.
///
/// One responder task is spawned per entry of `plan_responder_specs`. Each
/// task gets its own clone of `db` and `ignore_addrs`, and binds `port` on
/// its interface while advertising `tcp_port` in SEARCH replies. The
/// function resolves as soon as ANY responder task finishes. Previously only
/// the first task was awaited, so a failure on any other interface went
/// unnoticed. All remaining tasks are then aborted.
///
/// With no responder to spawn (not produced by the current planner, which
/// always yields at least one entry), returns `Ok(())` immediately.
///
/// NOTE(review): if this future itself is dropped before a responder exits,
/// the spawned tasks are not aborted here; whether they keep running
/// depends on the handle type's drop behaviour. Confirm that is intended.
///
/// # Errors
/// Returns the error of the first responder to exit. A join failure (e.g.
/// the task panicked) is mapped to `CaError::Io`.
pub async fn run_udp_search_responder(
    db: Arc<PvDatabase>,
    port: u16,
    tcp_port: u16,
    intf_addrs: Vec<Ipv4Addr>,
    ignore_addrs: Vec<Ipv4Addr>,
    mcast_addrs: Vec<Ipv4Addr>,
) -> CaResult<()> {
    let specs = plan_responder_specs(intf_addrs, &mcast_addrs);
    let mut handles = Vec::with_capacity(specs.len());
    for (bind_ip, mcast_for_intf) in specs {
        let db_t = db.clone();
        let ignore_t = ignore_addrs.clone();
        let handle = epics_base_rs::runtime::task::spawn(async move {
            run_single_responder(db_t, bind_ip, port, tcp_port, ignore_t, mcast_for_intf).await
        });
        handles.push(handle);
    }
    if handles.is_empty() {
        // Nothing to wait for; polling an empty set would never resolve.
        return Ok(());
    }

    // Resolve on whichever responder finishes first, so a failure on ANY
    // interface is surfaced rather than only a failure on the first one.
    let first_exit = std::future::poll_fn(|cx| {
        for handle in handles.iter_mut() {
            let polled = std::future::Future::poll(std::pin::Pin::new(handle), cx);
            if polled.is_ready() {
                return polled;
            }
        }
        std::task::Poll::Pending
    })
    .await;

    // Stop every other responder; aborting the one that already finished is
    // harmless.
    for handle in handles {
        handle.abort();
    }

    match first_exit {
        Ok(inner) => inner,
        Err(e) => Err(epics_base_rs::error::CaError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            e.to_string(),
        ))),
    }
}
/// Serve CA SEARCH requests for one bind address.
///
/// Binds the primary socket to `bind_ip:port`. It then asks that socket to
/// join every group in `mcast_groups` on interface `bind_ip`; a failed join
/// is logged as a warning and otherwise ignored.
///
/// On non-Windows targets, a second socket is bound to the address returned
/// by `super::addr_list::broadcast_for_ip(bind_ip)`, when there is one. If
/// that bind fails, the failure is logged and only the primary socket is
/// served. All receive loops started here share a single `UdpRateLimiter`
/// built by `UdpRateLimiter::from_env`.
///
/// NOTE(review): `recv_loop` sends replies on the socket it received on, so
/// searches that arrive on the broadcast-bound socket are answered from
/// that socket. Some IP stacks may reject, or mis-source, datagrams sent
/// from a socket bound to a broadcast address. Confirm replies actually
/// leave with `bind_ip` as their source.
///
/// # Errors
/// Returns an error when the primary socket cannot be bound, or as soon as
/// any of the receive loops fails.
async fn run_single_responder(
    db: Arc<PvDatabase>,
    bind_ip: Ipv4Addr,
    port: u16,
    tcp_port: u16,
    ignore_addrs: Vec<Ipv4Addr>,
    mcast_groups: Vec<Ipv4Addr>,
) -> CaResult<()> {
    let unicast = bind_responder_socket(bind_ip, port)?;
    for &group in mcast_groups.iter() {
        if let Err(e) = unicast.join_multicast_v4(group, bind_ip) {
            tracing::warn!(
                target: "epics_ca_rs::server::udp",
                %bind_ip,
                group = %group,
                error = %e,
                "CA server IP_ADD_MEMBERSHIP failed — \
                 SEARCH on this group will not reach this NIC"
            );
        } else {
            tracing::debug!(
                target: "epics_ca_rs::server::udp",
                %bind_ip,
                group = %group,
                "joined multicast group on responder socket"
            );
        }
    }
    let unicast = Arc::new(unicast);

    // Windows gets no separate broadcast listener.
    #[cfg(any(windows, target_os = "windows"))]
    let broadcast: Option<Arc<UdpSocket>> = None;
    #[cfg(not(any(windows, target_os = "windows")))]
    let broadcast: Option<Arc<UdpSocket>> = match super::addr_list::broadcast_for_ip(bind_ip) {
        None => None,
        Some(bcast_ip) => match bind_responder_socket(bcast_ip, port) {
            Ok(sock) => Some(Arc::new(sock)),
            Err(e) => {
                tracing::warn!(
                    target: "epics_ca_rs::server::udp",
                    %bind_ip,
                    %bcast_ip,
                    error = %e,
                    "CA server bcast responder bind failed; broadcast SEARCHes \
                     to this interface will not be answered"
                );
                None
            }
        },
    };

    let limiter = Arc::new(UdpRateLimiter::from_env());
    let unicast_loop = recv_loop(
        unicast,
        Arc::clone(&db),
        bind_ip,
        tcp_port,
        ignore_addrs.clone(),
        Arc::clone(&limiter),
    );
    let Some(bsock) = broadcast else {
        return unicast_loop.await;
    };
    let broadcast_loop = recv_loop(bsock, db, bind_ip, tcp_port, ignore_addrs, limiter);
    tokio::try_join!(unicast_loop, broadcast_loop).map(|_| ())
}
/// Open a non-blocking IPv4 UDP socket bound to `bind_ip:port` and hand it
/// to tokio.
///
/// Options are applied in this order:
/// * `SO_REUSEADDR` on every non-Windows target, plus `SO_REUSEPORT` on Unix;
/// * Linux only: `IP_MULTICAST_ALL` switched off, best effort (the result is
///   ignored);
/// * non-blocking mode, then the bind itself;
/// * after conversion to a tokio socket: `SO_BROADCAST` on, the multicast
///   TTL from `ca_mcast_ttl()` (result ignored), and `SO_RXQ_OVFL` via
///   `enable_so_rxq_ovfl_for_socket` (a failure is only traced).
///
/// # Errors
/// Propagates failures from creating the socket, setting the reuse options,
/// switching to non-blocking mode, binding, registering with tokio, or
/// enabling broadcast.
fn bind_responder_socket(bind_ip: Ipv4Addr, port: u16) -> CaResult<UdpSocket> {
    let raw = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;

    // Let several responders on this host share the search port.
    #[cfg(not(windows))]
    raw.set_reuse_address(true)?;
    #[cfg(unix)]
    raw.set_reuse_port(true)?;

    #[cfg(target_os = "linux")]
    let _ = raw.set_multicast_all_v4(false);

    // tokio requires the std socket to already be non-blocking.
    raw.set_nonblocking(true)?;
    raw.bind(&std::net::SocketAddrV4::new(bind_ip, port).into())?;

    let socket = UdpSocket::from_std(raw.into())?;
    socket.set_broadcast(true)?;
    let _ = socket.set_multicast_ttl_v4(epics_base_rs::runtime::net::ca_mcast_ttl());
    match enable_so_rxq_ovfl_for_socket(&socket) {
        Ok(_) => {}
        Err(e) => tracing::trace!(
            target: "epics_ca_rs::server::udp",
            %bind_ip,
            error = %e,
            "SO_RXQ_OVFL enable failed (non-fatal)"
        ),
    }
    Ok(socket)
}
/// Receive loop for one responder socket: answers CA SEARCH requests for
/// names this server hosts.
///
/// Every datagram accepted by `admit_datagram` is walked message by message:
/// * `CA_PROTO_VERSION` records the client's minor version (keeping the
///   highest value seen). When its `data_type` is 1, it also records the
///   client's sequence number from `cid`.
/// * `CA_PROTO_SEARCH` looks the requested name up with
///   `PvDatabase::has_name_from`, checked against the sender being answered.
///   On a hit, a reply is appended to a batch that `flush_send_buf` sends.
///
/// Parsing of a datagram stops at the first malformed, truncated,
/// unsupported-version, or non-VERSION/SEARCH message.
///
/// After a datagram is handled, datagrams already queued on the socket are
/// drained with `try_recv_from` and processed the same way. Consecutive
/// datagrams from one source share a reply batch; a change of source
/// flushes the batch and resets the recorded version state. Replies are
/// sent on `socket` itself.
///
/// # Errors
/// Only returns, with an error, when `recv_from_with_drop_count_socket`
/// fails.
async fn recv_loop(
    socket: Arc<UdpSocket>,
    db: Arc<PvDatabase>,
    bind_ip: Ipv4Addr,
    tcp_port: u16,
    ignore_addrs: Vec<Ipv4Addr>,
    udp_rl: Arc<UdpRateLimiter>,
) -> CaResult<()> {
    // Lowest client minor version accepted in VERSION and SEARCH `count`.
    const CA_MINIMUM_SUPPORTED_VERSION: u16 = 4;
    // Flush a reply batch before it would grow past this many bytes.
    const UDP_FLUSH_THRESHOLD: usize = 1024;
    // One SEARCH reply: header plus its 8-byte payload.
    const SEARCH_REPLY_LEN: usize = CaHeader::SIZE + 8;

    let mut buf = vec![0u8; 64 * 1024];
    let mut peek_buf = vec![0u8; 64 * 1024];
    // Reused for every datagram instead of a fresh 64 KiB allocation each time.
    let mut current_buf: Vec<u8> = Vec::with_capacity(64 * 1024);
    // Always empty between datagrams: every flush clears it.
    let mut send_buf: Vec<u8> = Vec::new();
    let mut prev_drops: u32 = 0;

    // SEARCH reply payload: our minor version (big-endian), zero-padded to 8 bytes.
    let mut search_payload = [0u8; 8];
    search_payload[0..2].copy_from_slice(&CA_MINOR_VERSION.to_be_bytes());

    loop {
        // NOTE(review): any receive error ends this loop, and with it the
        // responder. If this helper can surface per-datagram errors (e.g.
        // ICMP-induced connection resets on Windows), consider logging and
        // continuing instead.
        let (len, src, drops) = recv_from_with_drop_count_socket(&socket, &mut buf).await?;
        if drops != 0 && drops != prev_drops {
            tracing::debug!(
                target: "epics_ca_rs::server::udp",
                %bind_ip,
                prev = prev_drops,
                drops,
                "CA server UDP search responder buffer overflow"
            );
        }
        prev_drops = drops;
        if !admit_datagram(len, &src, &ignore_addrs, &udp_rl) {
            continue;
        }

        let mut current_src = src;
        current_buf.clear();
        current_buf.extend_from_slice(&buf[..len]);
        let mut client_seq: Option<u32> = None;
        let mut client_minor: Option<u16> = None;

        'parse: loop {
            let mut offset = 0;
            while offset + CaHeader::SIZE <= current_buf.len() {
                let Ok(hdr) = CaHeader::from_bytes(&current_buf[offset..]) else {
                    break;
                };
                // CA payload sizes are multiples of 8; anything else is corrupt.
                if (hdr.postsize as usize) & 0x7 != 0 {
                    break;
                }
                let msg_len = CaHeader::SIZE + hdr.postsize as usize;
                if offset + msg_len > current_buf.len() {
                    break;
                }
                if hdr.cmmd == CA_PROTO_VERSION {
                    if hdr.count < CA_MINIMUM_SUPPORTED_VERSION {
                        break;
                    }
                    client_minor = Some(client_minor.unwrap_or(0).max(hdr.count));
                    if hdr.data_type == 1 {
                        client_seq = Some(hdr.cid);
                    }
                } else if hdr.cmmd == CA_PROTO_SEARCH {
                    if hdr.count < CA_MINIMUM_SUPPORTED_VERSION {
                        break;
                    }
                    if hdr.postsize > 1 {
                        let payload = &current_buf[offset + CaHeader::SIZE..offset + msg_len];
                        // The last payload byte is never part of the name: it
                        // acts as a forced terminator.
                        let scan_end = payload.len().saturating_sub(1);
                        let name_end = payload[..scan_end]
                            .iter()
                            .position(|&b| b == 0)
                            .unwrap_or(scan_end);
                        if let Ok(pv_name) = std::str::from_utf8(&payload[..name_end]) {
                            // Check against the sender being answered: after a
                            // drain, `current_src` may differ from `src`.
                            if db.has_name_from(pv_name, Some(current_src)).await {
                                let mut resp = CaHeader::new(CA_PROTO_SEARCH);
                                resp.postsize = 8;
                                resp.data_type = tcp_port;
                                resp.count = 0;
                                // All-ones in `cid`; presumably the "use the reply's
                                // source address" marker. TODO confirm vs CA spec.
                                resp.cid = u32::MAX;
                                resp.available = hdr.available;
                                if !send_buf.is_empty()
                                    && send_buf.len() + SEARCH_REPLY_LEN > UDP_FLUSH_THRESHOLD
                                {
                                    flush_send_buf(
                                        &socket,
                                        current_src,
                                        &mut send_buf,
                                        client_minor,
                                        client_seq,
                                        &bind_ip,
                                    )
                                    .await;
                                }
                                if send_buf.is_empty() {
                                    // VERSION placeholder; flush_send_buf rewrites or
                                    // strips it depending on the client version.
                                    let mut ver = CaHeader::new(CA_PROTO_VERSION);
                                    ver.count = CA_MINOR_VERSION;
                                    send_buf.extend_from_slice(&ver.to_bytes());
                                }
                                send_buf.extend_from_slice(&resp.to_bytes());
                                send_buf.extend_from_slice(&search_payload);
                            }
                        }
                    }
                } else {
                    // Only VERSION and SEARCH are handled on the search port.
                    break;
                }
                offset += msg_len;
            }

            // Coalesce: pick up the next already-queued datagram without awaiting.
            let next = loop {
                match socket.try_recv_from(&mut peek_buf) {
                    Ok((peek_len, peek_src)) => {
                        if admit_datagram(peek_len, &peek_src, &ignore_addrs, &udp_rl) {
                            break Some((peek_len, peek_src));
                        }
                    }
                    Err(_) => break None,
                }
            };
            let Some((peek_len, peek_src)) = next else {
                break 'parse;
            };
            if peek_src != current_src {
                flush_send_buf(
                    &socket,
                    current_src,
                    &mut send_buf,
                    client_minor,
                    client_seq,
                    &bind_ip,
                )
                .await;
                current_src = peek_src;
                client_seq = None;
                client_minor = None;
            }
            current_buf.clear();
            current_buf.extend_from_slice(&peek_buf[..peek_len]);
        }

        if !send_buf.is_empty() {
            flush_send_buf(
                &socket,
                current_src,
                &mut send_buf,
                client_minor,
                client_seq,
                &bind_ip,
            )
            .await;
        }
    }
}

/// Admission checks shared by the awaited receive and the non-blocking
/// drain in `recv_loop`.
///
/// Returns `false` when the datagram is shorter than one CA header, comes
/// from an IPv4 address in `ignore_addrs`, or is refused by `udp_rl`. A
/// rate-limit refusal is counted in `ca_server_udp_search_drops_total`.
fn admit_datagram(
    len: usize,
    src: &SocketAddr,
    ignore_addrs: &[Ipv4Addr],
    udp_rl: &UdpRateLimiter,
) -> bool {
    if len < CaHeader::SIZE {
        return false;
    }
    if let SocketAddr::V4(v4) = src {
        if ignore_addrs.contains(v4.ip()) {
            return false;
        }
    }
    if !udp_rl.allow(src) {
        metrics::counter!("ca_server_udp_search_drops_total").increment(1);
        return false;
    }
    true
}
/// Send the batched SEARCH replies in `send_buf` to `src`, then clear the
/// buffer.
///
/// `send_buf` is built by `recv_loop` as a `CaHeader::SIZE`-byte VERSION
/// placeholder followed by one or more replies.
/// * For clients that announced minor version 11 or later, the placeholder
///   is rewritten to carry `CA_MINOR_VERSION` and, when known, the client's
///   sequence number (`cid = seq`, `data_type = 1`). The whole buffer is
///   then sent.
/// * For older clients, or clients that sent no VERSION, the placeholder is
///   stripped and only the replies are sent.
///
/// A buffer holding no replies (at most the placeholder) is discarded
/// without sending, so no empty or version-only datagram is ever emitted.
/// Send failures are logged and counted in
/// `ca_server_udp_search_reply_send_failures_total`; they are not returned.
async fn flush_send_buf(
    socket: &UdpSocket,
    src: SocketAddr,
    send_buf: &mut Vec<u8>,
    client_minor: Option<u16>,
    client_seq: Option<u32>,
    bind_ip: &Ipv4Addr,
) {
    if send_buf.len() <= CaHeader::SIZE {
        // No replies behind the placeholder: nothing worth sending.
        send_buf.clear();
        return;
    }
    let speaks_v411 = client_minor.is_some_and(|m| m >= 11);
    if speaks_v411 {
        let mut ver = CaHeader::new(CA_PROTO_VERSION);
        ver.count = CA_MINOR_VERSION;
        if let Some(seq) = client_seq {
            ver.cid = seq;
            ver.data_type = 1;
        }
        send_buf[..CaHeader::SIZE].copy_from_slice(&ver.to_bytes());
    }
    let payload: &[u8] = if speaks_v411 {
        &send_buf[..]
    } else {
        &send_buf[CaHeader::SIZE..]
    };
    if let Err(e) = socket.send_to(payload, src).await {
        tracing::warn!(
            target: "epics_ca_rs::server::udp",
            %bind_ip,
            dst = %src,
            payload_len = payload.len(),
            error = %e,
            "CA server UDP SEARCH-reply batch send failed"
        );
        metrics::counter!("ca_server_udp_search_reply_send_failures_total").increment(1);
    }
    send_buf.clear();
}
/// Optional per-source-IP limiter for incoming UDP search datagrams.
///
/// Configured by `EPICS_CAS_UDP_SEARCH_RATE_LIMIT`: the maximum number of
/// datagrams admitted per source IP in each one-second window. A missing,
/// unparsable or zero value disables limiting entirely.
struct UdpRateLimiter {
    enabled: bool,
    cap_per_sec: u32,
    /// Per-IP window state plus sweep bookkeeping, behind one lock.
    counts: std::sync::Mutex<RateTable>,
}

/// State guarded by `UdpRateLimiter::counts`.
#[derive(Default)]
struct RateTable {
    /// Source IP -> (start of its current window, datagrams admitted in it).
    per_ip: std::collections::HashMap<std::net::IpAddr, (std::time::Instant, u32)>,
    /// When stale entries were last pruned; `None` before the first prune.
    last_sweep: Option<std::time::Instant>,
}

impl UdpRateLimiter {
    /// Length of one counting window.
    const WINDOW: std::time::Duration = std::time::Duration::from_secs(1);
    /// Table size above which stale entries become eligible for pruning.
    const SWEEP_THRESHOLD: usize = 4096;
    /// Entries whose window started longer ago than this are pruned.
    const STALE_AFTER: std::time::Duration = std::time::Duration::from_secs(5);
    /// Minimum spacing between prunes. Without it, a flood from many
    /// distinct (possibly spoofed) sources would trigger a full-table scan
    /// on every admitted datagram.
    const SWEEP_INTERVAL: std::time::Duration = std::time::Duration::from_secs(1);

    /// Build the limiter from `EPICS_CAS_UDP_SEARCH_RATE_LIMIT`.
    fn from_env() -> Self {
        let cap = epics_base_rs::runtime::env::get("EPICS_CAS_UDP_SEARCH_RATE_LIMIT")
            .and_then(|s| s.parse().ok())
            .unwrap_or(0u32);
        Self {
            enabled: cap > 0,
            cap_per_sec: cap,
            counts: std::sync::Mutex::new(RateTable::default()),
        }
    }

    /// Record a datagram from `src` and report whether it may be processed.
    ///
    /// Always `true` when limiting is disabled. Otherwise each source IP may
    /// pass `cap_per_sec` datagrams per window; refused datagrams do not
    /// count towards the window.
    fn allow(&self, src: &SocketAddr) -> bool {
        if !self.enabled {
            return true;
        }
        let now = std::time::Instant::now();
        // A poisoned lock only means another thread panicked mid-update; the
        // table remains usable, so keep limiting rather than panic here too.
        let mut table = self
            .counts
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let entry = table.per_ip.entry(src.ip()).or_insert((now, 0));
        if now.duration_since(entry.0) >= Self::WINDOW {
            *entry = (now, 0);
        }
        if entry.1 >= self.cap_per_sec {
            return false;
        }
        entry.1 += 1;

        let sweep_due = match table.last_sweep {
            Some(prev) => now.saturating_duration_since(prev) >= Self::SWEEP_INTERVAL,
            None => true,
        };
        if table.per_ip.len() > Self::SWEEP_THRESHOLD && sweep_due {
            table
                .per_ip
                .retain(|_, (start, _)| now.saturating_duration_since(*start) <= Self::STALE_AFTER);
            table.last_sweep = Some(now);
        }
        true
    }
}
#[cfg(test)]
mod mr_r8_responder_plan_tests {
    //! Pins the socket plan produced by `plan_responder_specs`: one responder
    //! per bind address (a single wildcard one when no interfaces are
    //! configured), with every responder owning all multicast group joins.
    use super::plan_responder_specs;
    use std::net::Ipv4Addr;

    #[test]
    fn mr_r8_wildcard_with_mcast_groups_yields_one_responder() {
        let mcast_groups = vec![Ipv4Addr::new(224, 0, 0, 100), Ipv4Addr::new(224, 0, 0, 101)];
        let plan = plan_responder_specs(Vec::new(), &mcast_groups);
        assert_eq!(
            plan.len(),
            1,
            "wildcard config must produce exactly ONE responder \
             socket, not one-per-multicast-group"
        );
        let (bind_ip, joined) = &plan[0];
        assert_eq!(*bind_ip, Ipv4Addr::UNSPECIFIED);
        assert_eq!(
            joined, &mcast_groups,
            "the single wildcard responder must own ALL multicast \
             group joins (C `conf->udp` parity)"
        );

        // An explicit 0.0.0.0 entry must behave exactly like the empty list.
        let explicit = plan_responder_specs(vec![Ipv4Addr::UNSPECIFIED], &mcast_groups);
        assert_eq!(explicit.len(), 1);
        assert_eq!(explicit[0].1, mcast_groups);
    }

    #[test]
    fn mr_r8_specific_intfs_each_own_all_mcast_groups() {
        let mcast_groups = vec![Ipv4Addr::new(224, 0, 0, 200)];
        let intfs = vec![Ipv4Addr::new(10, 0, 0, 1), Ipv4Addr::new(10, 0, 0, 2)];
        let plan = plan_responder_specs(intfs.clone(), &mcast_groups);
        assert_eq!(plan.len(), 2, "one responder per interface entry");
        for ((bind_ip, joined), expected_ip) in plan.iter().zip(&intfs) {
            assert_eq!(bind_ip, expected_ip);
            assert_eq!(
                joined, &mcast_groups,
                "each interface responder joins every multicast group \
                 on its own socket (C `conf->udp` parity)"
            );
        }
    }

    #[test]
    fn mr_r8_no_mcast_groups_yields_single_plain_responder() {
        let plan = plan_responder_specs(Vec::new(), &[]);
        assert_eq!(plan.len(), 1);
        let (bind_ip, joined) = &plan[0];
        assert_eq!(*bind_ip, Ipv4Addr::UNSPECIFIED);
        assert!(joined.is_empty());
    }
}