use std::collections::HashMap;
use std::io;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket as StdUdpSocket};
use tokio::net::UdpSocket;
use crate::protocol::*;
/// One registered repeater client.
///
/// Each client gets its own UDP socket, bound to an ephemeral port and
/// connected to the address the client registered from, so send errors
/// (e.g. an ICMP port-unreachable turned into an error) are attributable to
/// that client.
struct RepeaterClient {
    /// Non-blocking socket bound to `0.0.0.0:0` and connected to `addr`.
    sock: StdUdpSocket,
    /// Address the client registered from.
    addr: SocketAddr,
}

impl RepeaterClient {
    /// Create a non-blocking UDP socket connected to `addr`.
    ///
    /// # Errors
    /// Propagates any failure from `bind`, `connect` or `set_nonblocking`.
    fn new(addr: SocketAddr) -> io::Result<Self> {
        let sock = StdUdpSocket::bind("0.0.0.0:0")?;
        sock.connect(addr)?;
        sock.set_nonblocking(true)?;
        Ok(Self { sock, addr })
    }

    /// Send a `CA_PROTO_REPEATER_CONFIRM` header to the client.
    ///
    /// For an IPv4 client the header's `available` field carries the client's
    /// address. Returns `true` if the send call succeeded.
    fn send_confirm(&self) -> bool {
        let mut confirm = CaHeader::new(CA_PROTO_REPEATER_CONFIRM);
        if let SocketAddr::V4(v4) = self.addr {
            confirm.available = u32::from_be_bytes(v4.ip().octets());
        }
        self.sock.send(&confirm.to_bytes()).is_ok()
    }

    /// Forward `data` to the client.
    ///
    /// Returns `false` only for errors that suggest nobody is listening any
    /// more; callers then confirm with [`RepeaterClient::verify`]. Other send
    /// errors are treated as transient and reported as success.
    fn send_message(&self, data: &[u8]) -> bool {
        match self.sock.send(data) {
            Ok(_) => true,
            // ConnectionReset: Windows reports an ICMP port-unreachable on a
            // connected UDP socket as WSAECONNRESET rather than ECONNREFUSED.
            Err(e) => !matches!(
                e.kind(),
                io::ErrorKind::ConnectionRefused
                    | io::ErrorKind::ConnectionReset
                    | io::ErrorKind::HostUnreachable
            ),
        }
    }

    /// Check whether the client's port still appears to be in use.
    ///
    /// Probes by binding a throw-away socket to `0.0.0.0:<client port>`.
    /// A successful bind means nobody holds the port, so the client is gone
    /// (`false`). Any bind error, `AddrInUse` or otherwise, is treated as
    /// "still alive". IPv6 clients always report `false`.
    fn verify(&self) -> bool {
        let SocketAddr::V4(v4) = self.addr else {
            return false;
        };
        StdUdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, v4.port())).is_err()
    }
}
/// Run the CA repeater with all stderr debug output disabled.
///
/// Thin wrapper over `run_repeater_with_debug` with a debug level of zero.
pub async fn run_repeater() -> io::Result<()> {
    const QUIET: u8 = 0;
    run_repeater_with_debug(QUIET).await
}
/// Run the CA repeater loop on the repeater UDP port.
///
/// Binds `0.0.0.0:<repeater_port()>` and joins any multicast beacon groups
/// named in the environment. It then loops forever:
/// - empty datagrams and `CA_PROTO_REPEATER_REGISTER` from local senders
///   register the sender as a client;
/// - everything else (and any payload chained after a REGISTER) is fanned out
///   to all registered clients.
///
/// `debug > 0` prints progress messages to stderr; higher values are more
/// verbose.
///
/// # Errors
/// Returns an error if the socket cannot be created, configured or bound
/// (e.g. another repeater already owns the port). It also returns an error if
/// receiving fails with an error not classified as transient by
/// `is_transient_recv_error`.
pub async fn run_repeater_with_debug(debug: u8) -> io::Result<()> {
    use socket2::{Domain, Protocol, Socket, Type};
    let sock = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
    #[cfg(target_os = "linux")]
    {
        // IP_MULTICAST_ALL=0: only deliver multicast for groups this socket joined.
        let _ = sock.set_multicast_all_v4(false);
    }
    sock.set_nonblocking(true)?;
    let bind_addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, repeater_port());
    sock.bind(&bind_addr.into())?;
    // NOTE: SO_REUSEADDR is set only *after* bind, so it cannot make the bind
    // above succeed when another repeater already holds the port. Presumably
    // deliberate; confirm before moving it ahead of `bind`.
    #[cfg(not(windows))]
    let _ = sock.set_reuse_address(true);
    join_beacon_multicast_groups(&sock);
    let std_sock: StdUdpSocket = sock.into();
    let socket = UdpSocket::from_std(std_sock)?;
    if let Err(e) = epics_base_rs::net::enable_so_rxq_ovfl_for_socket(&socket) {
        tracing::trace!(
            target: "epics_ca_rs::repeater",
            error = %e,
            "SO_RXQ_OVFL enable failed (non-fatal)"
        );
    }
    if debug > 0 {
        eprintln!("CA Repeater: Attached and initialized");
    }
    // Clients are keyed by their UDP port.
    let mut clients: HashMap<u16, RepeaterClient> = HashMap::new();
    let mut buf = [0u8; 4096];
    let mut prev_drops: u32 = 0;
    loop {
        let (len, src, drops) =
            match epics_base_rs::net::recv_from_with_drop_count_socket(&socket, &mut buf).await {
                Ok(received) => received,
                // A stray ICMP-induced error (or EINTR) must not take the
                // whole repeater down; skip it and keep serving.
                Err(e) if is_transient_recv_error(&e) => {
                    tracing::debug!(
                        target: "epics_ca_rs::repeater",
                        error = %e,
                        "CA repeater: ignoring transient UDP recv error"
                    );
                    continue;
                }
                Err(e) => return Err(e),
            };
        // Log only when the reported drop counter changes (presumably the
        // cumulative SO_RXQ_OVFL count; confirm against the helper).
        if drops != 0 && drops != prev_drops {
            tracing::debug!(
                target: "epics_ca_rs::repeater",
                prev = prev_drops,
                drops,
                "CA repeater UDP socket buffer overflow"
            );
        }
        prev_drops = drops;
        // An empty datagram is a bare registration request.
        if len == 0 {
            if !is_local_source(src) {
                tracing::warn!(
                    src = %src,
                    "caRepeater: zero-length registration from non-local source rejected"
                );
                metrics::counter!("ca_repeater_register_non_loopback_rejects_total").increment(1);
                continue;
            }
            register_client_debug(&mut clients, src, debug);
            continue;
        }
        // Too short to hold a CA header: dropped, not forwarded.
        if len < CaHeader::SIZE {
            continue;
        }
        let Ok(hdr) = CaHeader::from_bytes(&buf[..len]) else {
            continue;
        };
        let action = decode_datagram(&buf[..len], &hdr, src);
        if action.register {
            // Only senders on this host may register; a rejected REGISTER
            // also suppresses forwarding of any chained payload.
            if !is_local_source(src) {
                tracing::warn!(
                    src = %src,
                    "caRepeater: REPEATER_REGISTER from non-local source rejected"
                );
                metrics::counter!("ca_repeater_register_non_loopback_rejects_total").increment(1);
                continue;
            }
            register_client_debug(&mut clients, src, debug);
        }
        if let Some(data) = action.fanout {
            fan_out(&mut clients, src, &data, debug);
        }
    }
}

/// Receive errors the repeater survives instead of exiting.
///
/// ECONNREFUSED (spurious on Linux) and ECONNRESET (Windows, after an ICMP
/// port-unreachable) carry no information about the repeater socket itself.
/// Interrupted system calls are simply retried.
fn is_transient_recv_error(e: &io::Error) -> bool {
    matches!(
        e.kind(),
        io::ErrorKind::ConnectionRefused
            | io::ErrorKind::ConnectionReset
            | io::ErrorKind::Interrupted
    )
}
/// What the receive loop should do with one decoded datagram.
struct DatagramAction {
    /// Bytes to forward to every other registered client, if any.
    fanout: Option<Vec<u8>>,
    /// Whether the sender asked to be registered as a client.
    register: bool,
}
/// Whether `src` belongs to this host.
///
/// Loopback addresses (v4 or v6) are always local. A non-loopback IPv4
/// address counts as local when a socket can be bound to it, i.e. when it is
/// assigned to a local interface. Non-loopback IPv6 is never local.
fn is_local_source(src: SocketAddr) -> bool {
    match src {
        _ if src.ip().is_loopback() => true,
        SocketAddr::V4(v4) => {
            let probe = SocketAddrV4::new(*v4.ip(), 0);
            StdUdpSocket::bind(probe).is_ok()
        }
        SocketAddr::V6(_) => false,
    }
}
/// Decide what to do with one received datagram whose first header is `hdr`.
///
/// * `CA_PROTO_REPEATER_REGISTER`: register the sender. The REGISTER header
///   is stripped; any bytes after it are forwarded unchanged.
/// * Anything else: forward the whole datagram. For a `CA_PROTO_RSRV_IS_UP`
///   beacon whose `available` field is zero, the IPv4 source address is first
///   written into that field.
fn decode_datagram(buf: &[u8], hdr: &CaHeader, src: SocketAddr) -> DatagramAction {
    if hdr.cmmd != CA_PROTO_REPEATER_REGISTER {
        // Byte offset of the 32-bit `available` field inside the CA header.
        const AVAILABLE_OFFSET: usize = 12;
        let mut payload = buf.to_vec();
        let fill_source_ip = hdr.cmmd == CA_PROTO_RSRV_IS_UP && hdr.available == 0;
        if let (true, SocketAddr::V4(v4)) = (fill_source_ip, src) {
            payload[AVAILABLE_OFFSET..AVAILABLE_OFFSET + 4].copy_from_slice(&v4.ip().octets());
        }
        return DatagramAction {
            register: false,
            fanout: Some(payload),
        };
    }
    // Nothing (or nothing but the REGISTER header itself) -> no fan-out.
    let remainder = buf
        .get(CaHeader::SIZE..)
        .filter(|rest| !rest.is_empty())
        .map(|rest| rest.to_vec());
    DatagramAction {
        register: true,
        fanout: remainder,
    }
}
/// Forward `data` to every registered client except the sender itself.
///
/// The self-skip compares the full address (IP and port), not just the port.
/// When a send reports failure, the client is probed with `verify()`. Clients
/// whose port turns out to be free are removed once the pass completes.
/// `debug >= 1` reports refusals, deletions and the final client count;
/// `debug >= 2` also reports every successful send.
fn fan_out(clients: &mut HashMap<u16, RepeaterClient>, src: SocketAddr, data: &[u8], debug: u8) {
    let mut stale: Vec<u16> = Vec::new();
    for (&port, client) in clients.iter().filter(|(_, c)| c.addr != src) {
        if client.send_message(data) {
            if debug >= 2 {
                eprintln!("Sent to port {port}");
            }
            continue;
        }
        if debug >= 1 {
            eprintln!("Client on port {port} refused message");
        }
        let still_listening = client.verify();
        if !still_listening {
            stale.push(port);
        } else if debug >= 2 {
            eprintln!("Client on port {port} is alive");
        }
    }
    for p in stale {
        if debug >= 1 {
            eprintln!("Deleted client on port {p}");
        }
        clients.remove(&p);
    }
    if debug >= 1 {
        eprintln!("Verified {} active clients", clients.len());
    }
}
fn join_beacon_multicast_groups(sock: &socket2::Socket) {
let list = epics_base_rs::runtime::env::get("EPICS_CAS_BEACON_ADDR_LIST")
.or_else(|| epics_base_rs::runtime::env::get("EPICS_CA_ADDR_LIST"));
let Some(list) = list else {
return;
};
for token in list.split_whitespace() {
let host = token.rsplit_once(':').map(|(h, _)| h).unwrap_or(token);
let Ok(addr) = host.parse::<Ipv4Addr>() else {
continue;
};
if !addr.is_multicast() {
continue;
}
if let Err(e) = sock.join_multicast_v4(&addr, &Ipv4Addr::UNSPECIFIED) {
tracing::warn!(group = %addr, error = %e,
"ca-repeater: IP_ADD_MEMBERSHIP failed");
}
}
}
/// Upper bound on registered clients. At the limit, `register_client` first
/// evicts clients whose `verify()` fails and refuses the newcomer if the
/// table is still full.
const MAX_REPEATER_CLIENTS: usize = 1 << 10;
/// `register_client` plus stderr reporting.
///
/// With `debug >= 1`, announces a port that was not registered before and is
/// registered afterwards, together with the resulting client count.
fn register_client_debug(clients: &mut HashMap<u16, RepeaterClient>, src: SocketAddr, debug: u8) {
    let port = src.port();
    let known_before = clients.contains_key(&port);
    register_client(clients, src);
    let newly_added = !known_before && clients.contains_key(&port);
    if debug >= 1 && newly_added {
        eprintln!("New client on port {port}");
        eprintln!("Verified {} active clients", clients.len());
    }
}
/// Register, or re-confirm, the client that sent a REGISTER from `src`.
///
/// Clients are keyed by port only.
/// - Already-registered port: the existing entry is re-confirmed. If that
///   confirm cannot be sent, the entry is dropped, so the client's next
///   REGISTER goes through the fresh-client path below.
/// - New port: a client is created only while fewer than
///   `MAX_REPEATER_CLIENTS` are registered, after evicting any whose
///   `verify()` fails. It is kept only if its initial confirm is sent.
/// - After a new client is added, every other client is sent a
///   `CA_PROTO_VERSION` no-op. Clients whose send fails and whose port is
///   free again are removed.
fn register_client(clients: &mut HashMap<u16, RepeaterClient>, src: SocketAddr) {
    let port = src.port();
    if let Some(existing) = clients.get(&port) {
        if !existing.send_confirm() {
            clients.remove(&port);
        }
        return;
    }
    if clients.len() >= MAX_REPEATER_CLIENTS {
        // Table full: make room by evicting clients whose port is free again.
        clients.retain(|_, c| c.verify());
        if clients.len() >= MAX_REPEATER_CLIENTS {
            return;
        }
    }
    let Ok(client) = RepeaterClient::new(src) else {
        return;
    };
    if !client.send_confirm() {
        return;
    }
    clients.insert(port, client);
    // Poke every other client so dead ones are noticed even without beacon
    // traffic; keep a client if it is the new one, the send succeeded, or its
    // port is still in use.
    let noop = CaHeader::new(CA_PROTO_VERSION);
    let noop_bytes = noop.to_bytes();
    clients.retain(|&p, c| p == port || c.send_message(&noop_bytes) || c.verify());
}
/// Make sure a CA repeater is reachable on this host.
///
/// Tries to register. On failure it starts a repeater, waits 50 ms and tries
/// to register once more; the outcome of that second attempt is ignored.
pub async fn ensure_repeater() {
    let already_running = try_register().await.is_ok();
    if !already_running {
        spawn_repeater();
        let settle = std::time::Duration::from_millis(50);
        epics_base_rs::runtime::task::sleep(settle).await;
        try_register().await.ok();
    }
}
/// Register with the repeater at `127.0.0.1:<repeater_port()>`.
///
/// Sends one `CA_PROTO_REPEATER_REGISTER` from a fresh ephemeral socket and
/// waits up to 200 ms for a `CA_PROTO_REPEATER_CONFIRM`.
/// NOTE(review): the CONFIRM is accepted from any sender address.
///
/// Returns `Err(())` on any socket error or on timeout.
async fn try_register() -> Result<(), ()> {
    let socket = UdpSocket::bind("0.0.0.0:0").await.map_err(|_| ())?;
    let _ = epics_base_rs::net::enable_so_rxq_ovfl_for_socket(&socket);
    // A wildcard-bound socket reports 0.0.0.0 as its local address, which is
    // not a usable client address. The request goes over loopback, so
    // advertise 127.0.0.1 in that case.
    let local_ip = match socket.local_addr() {
        Ok(SocketAddr::V4(v4)) if !v4.ip().is_unspecified() => *v4.ip(),
        _ => Ipv4Addr::LOCALHOST,
    };
    let mut hdr = CaHeader::new(CA_PROTO_REPEATER_REGISTER);
    hdr.available = u32::from_be_bytes(local_ip.octets());
    let repeater_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, repeater_port());
    socket
        .send_to(&hdr.to_bytes(), repeater_addr)
        .await
        .map_err(|_| ())?;
    let mut buf = [0u8; 64];
    let result = tokio::time::timeout(std::time::Duration::from_millis(200), async {
        loop {
            let (len, _, _drops) =
                epics_base_rs::net::recv_from_with_drop_count_socket(&socket, &mut buf)
                    .await
                    .map_err(|_| ())?;
            // Ignore anything that is not a well-formed CONFIRM header.
            if len >= CaHeader::SIZE {
                if let Ok(resp) = CaHeader::from_bytes(&buf[..len]) {
                    if resp.cmmd == CA_PROTO_REPEATER_CONFIRM {
                        return Ok::<(), ()>(());
                    }
                }
            }
        }
    })
    .await;
    match result {
        Ok(Ok(())) => Ok(()),
        _ => Err(()),
    }
}
fn spawn_repeater() {
let exe = std::env::current_exe().unwrap_or_default();
let repeater_bin = exe.parent().map(|p| p.join("ca-repeater-rs"));
if let Some(ref bin) = repeater_bin {
if bin.exists() {
use std::process::{Command, Stdio};
if Command::new(bin)
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn()
.is_ok()
{
return;
}
}
}
std::thread::spawn(|| {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("repeater runtime");
let _ = rt.block_on(run_repeater());
});
}
#[cfg(test)]
mod tests {
use super::*;
/// Wire bytes of a single `CaHeader` with `cmmd` and `available` set; every
/// other field keeps whatever `CaHeader::new` initialises it to.
fn header_bytes(cmmd: u16, available: u32) -> Vec<u8> {
let mut h = CaHeader::new(cmmd);
h.available = available;
h.to_bytes().to_vec()
}
/// Shorthand for an IPv4 `SocketAddr`.
fn src_v4(a: u8, b: u8, c: u8, d: u8, port: u16) -> SocketAddr {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(a, b, c, d), port))
}
/// An RSRV_IS_UP beacon with `available == 0` is forwarded with the sender's
/// IPv4 address written into bytes 12..16 (the `available` field).
#[test]
fn beacon_rewrites_zero_m_available_with_source_ip() {
let buf = header_bytes(CA_PROTO_RSRV_IS_UP, 0);
let hdr = CaHeader::from_bytes(&buf).unwrap();
let src = src_v4(10, 0, 0, 5, 4321);
let act = decode_datagram(&buf, &hdr, src);
assert!(!act.register);
let data = act.fanout.expect("beacon must be fanned out");
assert_eq!(&data[12..16], &[10, 0, 0, 5]);
}
/// A beacon that already carries a non-zero `available` is forwarded as-is.
#[test]
fn beacon_with_nonzero_m_available_is_unchanged() {
let buf = header_bytes(CA_PROTO_RSRV_IS_UP, 0x0a00_0006);
let hdr = CaHeader::from_bytes(&buf).unwrap();
let src = src_v4(192, 168, 1, 99, 5555);
let act = decode_datagram(&buf, &hdr, src);
assert!(!act.register);
let data = act.fanout.expect("beacon must be fanned out");
assert_eq!(&data[12..16], &0x0a00_0006u32.to_be_bytes());
}
/// Only RSRV_IS_UP gets the source-address rewrite; other commands with a
/// zero `available` field are forwarded untouched.
#[test]
fn non_rsrv_non_register_message_is_not_rewritten() {
let buf = header_bytes(CA_PROTO_VERSION, 0);
let hdr = CaHeader::from_bytes(&buf).unwrap();
let src = src_v4(10, 0, 0, 5, 4321);
let act = decode_datagram(&buf, &hdr, src);
assert!(!act.register);
let data = act.fanout.expect("fan out");
assert_eq!(&data[12..16], &[0, 0, 0, 0]);
}
/// A REGISTER consisting of just the header registers and forwards nothing.
#[test]
fn bare_register_returns_register_only_no_fanout() {
let buf = header_bytes(CA_PROTO_REPEATER_REGISTER, 0);
let hdr = CaHeader::from_bytes(&buf).unwrap();
let src = src_v4(127, 0, 0, 1, 9000);
let act = decode_datagram(&buf, &hdr, src);
assert!(act.register);
assert!(
act.fanout.is_none(),
"bare REGISTER must not fan out anything"
);
}
/// REGISTER followed by another message: the REGISTER header is stripped
/// and the remainder is forwarded verbatim. Its RSRV_IS_UP `available` stays
/// zero, i.e. the chained beacon is NOT rewritten.
#[test]
fn chained_register_plus_payload_strips_then_fans_out_remainder() {
let mut buf = header_bytes(CA_PROTO_REPEATER_REGISTER, 0);
let remainder = header_bytes(CA_PROTO_RSRV_IS_UP, 0);
buf.extend_from_slice(&remainder);
let hdr = CaHeader::from_bytes(&buf).unwrap();
let src = src_v4(10, 0, 0, 5, 5060);
let act = decode_datagram(&buf, &hdr, src);
assert!(act.register, "REGISTER must register the sender");
let data = act.fanout.expect("chained payload must fan out");
assert_eq!(data.len(), CaHeader::SIZE);
assert_eq!(&data, &remainder);
assert_eq!(&data[12..16], &[0, 0, 0, 0]);
}
/// `fan_out` skips a client only when the source matches its full address.
/// The first call comes from a different IP with the same port and must be
/// delivered. The second comes from the client's own address and must not
/// be. The 750 ms read timeout turns "nothing arrived" into WouldBlock or
/// TimedOut (which one is platform-dependent). The two calls must stay in
/// this order.
#[test]
fn fan_out_skips_on_full_address_not_port_alone() {
let recv = StdUdpSocket::bind("127.0.0.1:0").expect("bind recv");
recv.set_read_timeout(Some(std::time::Duration::from_millis(750)))
.unwrap();
let local = recv.local_addr().unwrap();
let port = local.port();
let mut clients: HashMap<u16, RepeaterClient> = HashMap::new();
clients.insert(port, RepeaterClient::new(local).expect("client sock"));
let data = header_bytes(CA_PROTO_RSRV_IS_UP, 0x0a00_0005);
// Same port as the client, different IP: must NOT be treated as self.
let server_src = src_v4(10, 0, 0, 5, port);
fan_out(&mut clients, server_src, &data, 0);
let mut buf = [0u8; 64];
let n = recv
.recv(&mut buf)
.expect("client with a coinciding port must still receive the beacon");
assert_eq!(
&buf[..n],
&data[..],
"fanned-out bytes must match the input"
);
// Exact client address: must be skipped, so the read times out.
fan_out(&mut clients, local, &data, 0);
let err = recv.recv(&mut buf).expect_err(
"a datagram from the client's own full address must be skipped, not reflected",
);
assert!(
matches!(
err.kind(),
std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
),
"expected a read timeout for the self-skip case, got {err:?}"
);
}
}