Skip to main content

spvirit_client/
search.rs

1use std::collections::HashSet;
2use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
3use std::sync::Arc;
4use std::time::Duration;
5
6use dns_lookup::lookup_host;
7use get_if_addrs::{get_if_addrs, IfAddr};
8use socket2::{Domain, Protocol, Socket, Type};
9use tokio::io::AsyncWriteExt;
10use tokio::net::UdpSocket;
11use tracing::debug;
12
13use crate::auth::{default_authnz_host, default_authnz_user};
14use crate::transport::read_packet;
15use crate::types::{PvGetError, PvGetOptions};
16use spvirit_codec::epics_decode::{PvaPacket, PvaPacketCommand};
17use spvirit_codec::spvirit_encode::{
18    encode_client_connection_validation, encode_search_request, ip_to_bytes,
19    socket_addr_from_pva_bytes,
20};
21
/// One UDP search destination plus the local address the sending socket binds to.
#[derive(Debug, Clone, Copy)]
pub struct SearchTarget {
    /// Address the search request is sent to (unicast, broadcast or multicast).
    pub target: IpAddr,
    /// Local address used to bind the socket that sends to `target`.
    pub bind: IpAddr,
}
27
/// A PVA server found by discovery: its 12-byte GUID and the TCP address to connect to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DiscoveredServer {
    /// Server GUID as carried in the search response.
    pub guid: [u8; 12],
    /// TCP endpoint of the server.
    pub tcp_addr: SocketAddr,
}
33
34pub fn parse_addr_list(env: &str) -> Vec<IpAddr> {
35    env.split(|c| c == ',' || c == ' ' || c == '\t')
36        .filter(|s| !s.trim().is_empty())
37        .filter_map(|s| parse_search_target_ip(s.trim()))
38        .collect()
39}
40
41fn parse_search_target_ip(token: &str) -> Option<IpAddr> {
42    if token.is_empty() {
43        return None;
44    }
45
46    if let Ok(ip) = token.parse::<IpAddr>() {
47        return Some(ip);
48    }
49    if let Ok(sock) = token.parse::<SocketAddr>() {
50        return Some(sock.ip());
51    }
52
53    // Accept host:port where host may be a name or an IP literal.
54    // For IPv6 bracket notation [::1]:port, SocketAddr::parse above already handles it.
55    if let Some((host, port_str)) = token.rsplit_once(':') {
56        if !host.is_empty()
57            && !port_str.is_empty()
58            && port_str.chars().all(|c| c.is_ascii_digit())
59            && !host.contains(']')
60        {
61            if let Ok(ip) = host.parse::<IpAddr>() {
62                return Some(ip);
63            }
64            if let Ok(addrs) = lookup_host(host) {
65                // Prefer IPv4 for backward compat, fall back to first IPv6
66                let addrs: Vec<IpAddr> = addrs.collect();
67                if let Some(ip) = addrs.iter().find(|ip| ip.is_ipv4()).copied()
68                    .or_else(|| addrs.into_iter().next())
69                {
70                    return Some(ip);
71                }
72            }
73        }
74    }
75
76    if let Ok(addrs) = lookup_host(token) {
77        // Prefer IPv4, fall back to first IPv6
78        let addrs: Vec<IpAddr> = addrs.collect();
79        if let Some(ip) = addrs.iter().find(|ip| ip.is_ipv4()).copied()
80            .or_else(|| addrs.into_iter().next())
81        {
82            return Some(ip);
83        }
84    }
85
86    None
87}
88
/// Return the unspecified ("any") address of the same family as `ip`:
/// `0.0.0.0` for IPv4, `::` for IPv6.
fn unspecified_for(ip: IpAddr) -> IpAddr {
    if ip.is_ipv4() {
        Ipv4Addr::UNSPECIFIED.into()
    } else {
        Ipv6Addr::UNSPECIFIED.into()
    }
}
96
97pub fn build_search_targets(
98    search_addr: Option<IpAddr>,
99    bind_addr: Option<IpAddr>,
100) -> Vec<SearchTarget> {
101    // Explicit --search-addr overrides everything (single target).
102    if let Some(ip) = search_addr {
103        return vec![SearchTarget {
104            target: ip,
105            bind: bind_addr.unwrap_or_else(|| unspecified_for(ip)),
106        }];
107    }
108
109    let mut targets = Vec::new();
110    let mut seen = HashSet::new();
111
112    // Addresses from EPICS_PVA_ADDR_LIST.
113    if let Ok(env) = std::env::var("EPICS_PVA_ADDR_LIST") {
114        for ip in parse_addr_list(&env) {
115            if seen.insert(ip) {
116                targets.push(SearchTarget {
117                    target: ip,
118                    bind: bind_addr.unwrap_or_else(|| unspecified_for(ip)),
119                });
120            }
121        }
122    }
123
124    // Merge auto-discovered broadcast addresses unless explicitly disabled.
125    // This matches EPICS Base behaviour: ADDR_LIST + auto-broadcast combined.
126    if is_auto_addr_list_enabled() {
127        for t in build_auto_broadcast_targets() {
128            if seen.insert(t.target) {
129                targets.push(SearchTarget {
130                    target: t.target,
131                    bind: bind_addr.unwrap_or(t.bind),
132                });
133            }
134        }
135    }
136
137    targets
138}
139
/// Report whether auto-discovered broadcast/multicast targets should be used.
///
/// Reads `EPICS_PVA_AUTO_ADDR_LIST`:
/// * If it is unset (or not valid Unicode), the feature is enabled.
/// * If it is set, it is enabled only for `YES`, `Y`, `1` or `TRUE`. The
///   comparison ignores case and surrounding whitespace.
/// * Any other value, including an empty string, disables it.
pub fn is_auto_addr_list_enabled() -> bool {
    std::env::var("EPICS_PVA_AUTO_ADDR_LIST")
        .map(|raw| {
            let flag = raw.trim().to_ascii_uppercase();
            matches!(flag.as_str(), "YES" | "Y" | "1" | "TRUE")
        })
        .unwrap_or(true)
}
149
/// True for IPv4 link-local addresses (169.254.0.0/16).
fn ipv4_is_link_local(ip: Ipv4Addr) -> bool {
    ip.is_link_local()
}
154
155fn choose_default_bind_v4() -> Option<Ipv4Addr> {
156    let ifaces = get_if_addrs().ok()?;
157    for iface in ifaces {
158        if let IfAddr::V4(v4) = iface.addr {
159            let ip = v4.ip;
160            if ip.is_loopback() || ipv4_is_link_local(ip) {
161                continue;
162            }
163            return Some(ip);
164        }
165    }
166    None
167}
168
169fn choose_default_bind_v6() -> Option<Ipv6Addr> {
170    let ifaces = get_if_addrs().ok()?;
171    for iface in ifaces {
172        if let IfAddr::V6(v6) = iface.addr {
173            let ip = v6.ip;
174            if ip.is_loopback() {
175                continue;
176            }
177            // Skip link-local (fe80::/10) — not routable without scope id
178            let segs = ip.segments();
179            if segs[0] & 0xffc0 == 0xfe80 {
180                continue;
181            }
182            return Some(ip);
183        }
184    }
185    None
186}
187
/// Directed broadcast address of the subnet `ip`/`netmask`: every host bit
/// (a bit cleared in `netmask`) is set to 1.
fn broadcast_for(ip: Ipv4Addr, netmask: Ipv4Addr) -> Ipv4Addr {
    let mut octets = ip.octets();
    let mask = netmask.octets();
    for (octet, mask_byte) in octets.iter_mut().zip(mask.iter()) {
        *octet |= !*mask_byte;
    }
    Ipv4Addr::from(octets)
}
193
194fn discovery_target_for(ip: Ipv4Addr, netmask: Ipv4Addr) -> Ipv4Addr {
195    let limited_broadcast = Ipv4Addr::new(255, 255, 255, 255);
196    if netmask == Ipv4Addr::new(255, 255, 255, 255) || netmask.is_unspecified() {
197        return limited_broadcast;
198    }
199    let directed = broadcast_for(ip, netmask);
200    if directed == ip {
201        limited_broadcast
202    } else {
203        directed
204    }
205}
206
207pub fn build_auto_broadcast_targets() -> Vec<SearchTarget> {
208    let mut targets = Vec::new();
209    let mut fallback_targets = Vec::new();
210    let mut fallback_seen = HashSet::new();
211    let mut added_v4_multicast = false;
212    let mut added_v6_multicast = false;
213    let ifaces = match get_if_addrs() {
214        Ok(v) => v,
215        Err(_) => return targets,
216    };
217    for iface in &ifaces {
218        if let IfAddr::V4(v4) = &iface.addr {
219            let ip = v4.ip;
220            if ip.is_loopback() || ipv4_is_link_local(ip) {
221                continue;
222            }
223            let bcast = discovery_target_for(ip, v4.netmask);
224            targets.push(SearchTarget {
225                target: IpAddr::V4(bcast),
226                bind: IpAddr::V4(ip),
227            });
228            // Also send to IPv4 multicast group (matching PVXS behaviour).
229            // Docker overlay networks may block broadcast but allow multicast.
230            targets.push(SearchTarget {
231                target: IpAddr::V4(PVA_MULTICAST_V4),
232                bind: IpAddr::V4(ip),
233            });
234            if fallback_seen.insert(IpAddr::V4(bcast)) {
235                fallback_targets.push(SearchTarget {
236                    target: IpAddr::V4(bcast),
237                    bind: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
238                });
239            }
240            if !added_v4_multicast {
241                added_v4_multicast = true;
242                fallback_targets.push(SearchTarget {
243                    target: IpAddr::V4(PVA_MULTICAST_V4),
244                    bind: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
245                });
246            }
247        }
248    }
249    // Add IPv6 multicast targets for each non-loopback, non-link-local v6 iface.
250    for iface in &ifaces {
251        if let IfAddr::V6(v6) = &iface.addr {
252            let ip = v6.ip;
253            if ip.is_loopback() {
254                continue;
255            }
256            let segs = ip.segments();
257            if segs[0] & 0xffc0 == 0xfe80 {
258                continue; // skip link-local
259            }
260            let multicast_target = IpAddr::V6(PVA_MULTICAST_V6);
261            targets.push(SearchTarget {
262                target: multicast_target,
263                bind: IpAddr::V6(ip),
264            });
265            if !added_v6_multicast {
266                added_v6_multicast = true;
267                fallback_targets.push(SearchTarget {
268                    target: multicast_target,
269                    bind: IpAddr::V6(Ipv6Addr::UNSPECIFIED),
270                });
271            }
272        }
273    }
274    targets.extend(fallback_targets);
275    targets
276}
277
/// PVA multicast group (IPv4): 224.0.0.128.
const PVA_MULTICAST_V4: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 128);

/// PVA multicast group (IPv6, link-local scope): ff02::42:1.
const PVA_MULTICAST_V6: Ipv6Addr =
    Ipv6Addr::new(0xff02, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0042, 0x0001);
283
284/// Best-effort join the PVA multicast group appropriate for the bind address.
285fn join_multicast_any(socket: &std::net::UdpSocket, bind: IpAddr) {
286    match bind {
287        IpAddr::V4(iface) => {
288            let _ = socket.join_multicast_v4(&PVA_MULTICAST_V4, &iface);
289        }
290        IpAddr::V6(_) => {
291            // interface index 0 = OS picks the default interface
292            let _ = socket.join_multicast_v6(&PVA_MULTICAST_V6, 0);
293        }
294    }
295}
296
297fn decode_search_response_addr(addr: [u8; 16], port: u16, src: SocketAddr) -> SocketAddr {
298    socket_addr_from_pva_bytes(addr, port)
299        .filter(|a| !a.ip().is_unspecified())
300        .unwrap_or_else(|| SocketAddr::new(src.ip(), port))
301}
302
303fn normalize_discovered_servers(items: Vec<DiscoveredServer>) -> Vec<DiscoveredServer> {
304    let mut seen = HashSet::new();
305    let mut out = Vec::new();
306    for item in items {
307        if seen.insert((item.guid, item.tcp_addr)) {
308            out.push(item);
309        }
310    }
311    out.sort_by(|a, b| a.tcp_addr.to_string().cmp(&b.tcp_addr.to_string()));
312    out
313}
314
315/// Create a UDP socket with SO_REUSEADDR set (matching PVXS behaviour),
316/// allowing multiple processes to share the search port.
317///
318/// On Windows SO_REUSEADDR has different (unsafe) semantics — it allows
319/// a second socket to steal an actively-used port — so we only enable it
320/// on Unix where it merely permits rebinding during TIME_WAIT.
321fn bind_udp_reuse(addr: SocketAddr) -> std::io::Result<std::net::UdpSocket> {
322    let domain = if addr.is_ipv4() {
323        Domain::IPV4
324    } else {
325        Domain::IPV6
326    };
327    let sock = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
328    #[cfg(unix)]
329    sock.set_reuse_address(true)?;
330    sock.set_nonblocking(true)?;
331    sock.bind(&addr.into())?;
332    Ok(sock.into())
333}
334
335pub async fn search_pv(
336    pv_name: &str,
337    udp_port: u16,
338    timeout_dur: Duration,
339    targets: &[SearchTarget],
340    debug_enabled: bool,
341) -> Result<SocketAddr, PvGetError> {
342    if targets.is_empty() {
343        return Err(PvGetError::Search("no search targets"));
344    }
345
346    let now = std::time::SystemTime::now()
347        .duration_since(std::time::UNIX_EPOCH)
348        .unwrap_or_default();
349    let seq = (now.as_nanos() as u32).wrapping_add(std::process::id());
350    let cid = seq ^ 0x9E37_79B9;
351
352    let mut last_io_error: Option<std::io::Error> = None;
353    let deadline = tokio::time::Instant::now() + timeout_dur;
354
355    // Group targets by bind address so we can share a socket per bind.
356    let mut bind_groups: Vec<(IpAddr, Vec<IpAddr>)> = Vec::new();
357    for t in targets {
358        if let Some(group) = bind_groups.iter_mut().find(|(b, _)| *b == t.bind) {
359            group.1.push(t.target);
360        } else {
361            bind_groups.push((t.bind, vec![t.target]));
362        }
363    }
364
365    // Open sockets and send to all targets first, then collect responses.
366    // Store (socket, message, destinations) for retransmission.
367    let mut socket_info: Vec<(Arc<UdpSocket>, Vec<u8>, Vec<SocketAddr>)> = Vec::new();
368
369    for (bind_ip, group_targets) in &bind_groups {
370        let bind_addr = SocketAddr::new(*bind_ip, udp_port);
371        let (std_sock, actual_bind_addr) = match bind_udp_reuse(bind_addr) {
372            Ok(sock) => (sock, bind_addr),
373            Err(err) if err.kind() == std::io::ErrorKind::AddrInUse => {
374                let fallback = SocketAddr::new(*bind_ip, 0);
375                match bind_udp_reuse(fallback) {
376                    Ok(sock) => {
377                        let actual = sock.local_addr().unwrap_or(fallback);
378                        if debug_enabled {
379                            debug!(
380                                "pva search bind={} failed (in use), fallback bind={}",
381                                bind_addr, actual
382                            );
383                        }
384                        (sock, actual)
385                    }
386                    Err(fallback_err) => {
387                        if debug_enabled {
388                            debug!(
389                                "pva search skipping bind={} step=bind-fallback kind={:?} err={}",
390                                bind_addr,
391                                fallback_err.kind(),
392                                fallback_err
393                            );
394                        }
395                        last_io_error = Some(fallback_err);
396                        continue;
397                    }
398                }
399            }
400            Err(err) => {
401                if debug_enabled {
402                    debug!(
403                        "pva search skipping bind={} step=bind kind={:?} err={}",
404                        bind_addr,
405                        err.kind(),
406                        err
407                    );
408                }
409                last_io_error = Some(err);
410                continue;
411            }
412        };
413        if let Err(err) = std_sock.set_broadcast(true) {
414            if debug_enabled {
415                debug!(
416                    "pva search skipping bind={} step=set_broadcast kind={:?} err={}",
417                    bind_addr,
418                    err.kind(),
419                    err
420                );
421            }
422            last_io_error = Some(err);
423            continue;
424        }
425
426        join_multicast_any(&std_sock, *bind_ip);
427
428        let reply_addr = ip_to_bytes(*bind_ip);
429        let reply_port = match std_sock.local_addr() {
430            Ok(addr) => addr.port(),
431            Err(err) => {
432                if debug_enabled {
433                    debug!(
434                        "pva search skipping bind={} step=local_addr kind={:?} err={}",
435                        bind_addr,
436                        err.kind(),
437                        err
438                    );
439                }
440                last_io_error = Some(err);
441                continue;
442            }
443        };
444        let requests = [(cid, pv_name)];
445        let msg = encode_search_request(seq, 0x81, reply_port, reply_addr, &requests, 2, false);
446
447        let socket = match UdpSocket::from_std(std_sock) {
448            Ok(socket) => socket,
449            Err(err) => {
450                if debug_enabled {
451                    debug!(
452                        "pva search skipping bind={} step=from_std kind={:?} err={}",
453                        bind_addr,
454                        err.kind(),
455                        err
456                    );
457                }
458                last_io_error = Some(err);
459                continue;
460            }
461        };
462
463        let dests: Vec<SocketAddr> = group_targets
464            .iter()
465            .map(|ip| SocketAddr::new(*ip, udp_port))
466            .collect();
467
468        // Send to every target in this bind group immediately.
469        for dest in &dests {
470            if debug_enabled {
471                debug!(
472                    "pva search bind={} target={} server_port={} reply_port={}",
473                    actual_bind_addr, dest.ip(), udp_port, reply_port
474                );
475                debug!("pva search seq={} cid={}", seq, cid);
476                debug!("pva search send {} bytes to {}", msg.len(), dest);
477            }
478            if let Err(err) = socket.send_to(&msg, dest).await {
479                if debug_enabled {
480                    debug!(
481                        "pva search send_to target={} kind={:?} err={}",
482                        dest,
483                        err.kind(),
484                        err
485                    );
486                }
487                last_io_error = Some(err);
488            }
489        }
490
491        socket_info.push((Arc::new(socket), msg, dests));
492    }
493
494    if socket_info.is_empty() {
495        if let Some(err) = last_io_error {
496            return Err(PvGetError::Io(err));
497        }
498        return Err(PvGetError::Timeout("search response"));
499    }
500
501    // Spawn a receiver task per socket that forwards packets into a shared channel.
502    let (tx, mut rx) = tokio::sync::mpsc::channel::<(Vec<u8>, SocketAddr)>(64);
503    for (sock, _, _) in &socket_info {
504        let sock = Arc::clone(sock);
505        let tx = tx.clone();
506        tokio::spawn(async move {
507            loop {
508                let mut buf = vec![0u8; 2048];
509                match sock.recv_from(&mut buf).await {
510                    Ok((len, src)) => {
511                        buf.truncate(len);
512                        if tx.send((buf, src)).await.is_err() {
513                            break;
514                        }
515                    }
516                    Err(_) => break,
517                }
518            }
519        });
520    }
521    drop(tx); // Only spawned tasks hold senders; channel closes when they exit.
522
523    // Retransmit schedule: exponential backoff from start.
524    let retransmit_offsets = [100u64, 500, 1000, 2000];
525    let start = tokio::time::Instant::now();
526    let mut next_retransmit = 0usize;
527
528    loop {
529        // Compute the next wake-up: either the next retransmit or the deadline.
530        let next_retransmit_at = if next_retransmit < retransmit_offsets.len() {
531            start + Duration::from_millis(retransmit_offsets[next_retransmit])
532        } else {
533            deadline
534        };
535        let wake_at = next_retransmit_at.min(deadline);
536
537        tokio::select! {
538            recv = rx.recv() => {
539                let Some((buf, src)) = recv else { break };
540                let mut pkt = PvaPacket::new(&buf);
541                let cmd = pkt
542                    .decode_payload()
543                    .ok_or(PvGetError::Search("failed to decode search response"))?;
544                if let PvaPacketCommand::SearchResponse(payload) = cmd {
545                    if debug_enabled {
546                        debug!(
547                            "pva search response found={} cids={:?} addr={:?} port={}",
548                            payload.found, payload.cids, payload.addr, payload.port
549                        );
550                    }
551                    if payload.seq != seq {
552                        continue;
553                    }
554                    if !payload.protocol.is_empty() && !payload.protocol.eq_ignore_ascii_case("tcp") {
555                        continue;
556                    }
557                    if !payload.found {
558                        continue;
559                    }
560                    if !payload.cids.is_empty() && !payload.cids.contains(&cid) {
561                        continue;
562                    }
563
564                    let addr = decode_search_response_addr(payload.addr, payload.port, src);
565                    if debug_enabled {
566                        debug!("pva search response from {}", addr);
567                    }
568                    return Ok(addr);
569                }
570            }
571            _ = tokio::time::sleep_until(wake_at) => {
572                if tokio::time::Instant::now() >= deadline {
573                    break;
574                }
575                // Retransmit to all targets on all sockets.
576                if next_retransmit < retransmit_offsets.len() {
577                    if debug_enabled {
578                        debug!("pva search retransmit round {}", next_retransmit + 1);
579                    }
580                    for (sock, msg, dests) in &socket_info {
581                        for dest in dests {
582                            let _ = sock.send_to(msg, dest).await;
583                        }
584                    }
585                    next_retransmit += 1;
586                }
587            }
588        }
589    }
590
591    Err(PvGetError::Timeout("search response"))
592}
593
594pub fn default_bind_ip() -> Option<IpAddr> {
595    choose_default_bind_v4()
596        .map(IpAddr::V4)
597        .or_else(|| choose_default_bind_v6().map(IpAddr::V6))
598}
599
/// Parse an `EPICS_PVA_NAME_SERVERS` value into socket addresses.
///
/// Entries are separated by commas, spaces or tabs. Each entry may be any of:
/// * `ip:port` / `[v6]:port`;
/// * a bare IP (port 5075);
/// * `host:port`;
/// * a bare host name (port 5075).
///
/// Host names are resolved with `ToSocketAddrs`, keeping the first result.
/// Entries that cannot be resolved are skipped.
pub fn parse_name_servers(env_val: &str) -> Vec<SocketAddr> {
    use std::net::ToSocketAddrs;

    const DEFAULT_PORT: u16 = 5075;

    let first_resolved =
        |spec: &str| -> Option<SocketAddr> { spec.to_socket_addrs().ok().and_then(|mut it| it.next()) };

    env_val
        .split(|c| c == ',' || c == ' ' || c == '\t')
        .map(str::trim)
        .filter(|token| !token.is_empty())
        .filter_map(|token| {
            if let Ok(addr) = token.parse::<SocketAddr>() {
                return Some(addr);
            }
            if let Ok(ip) = token.parse::<IpAddr>() {
                return Some(SocketAddr::new(ip, DEFAULT_PORT));
            }
            first_resolved(token)
                .or_else(|| first_resolved(&format!("{}:{}", token, DEFAULT_PORT)))
        })
        .collect()
}
634
635/// Build a minimal PVA ConnectionValidation response for name server search.
636fn encode_search_validation(version: u8, is_be: bool) -> Vec<u8> {
637    let user = default_authnz_user();
638    let host = default_authnz_host();
639    encode_client_connection_validation(87_040, 32_767, 0, "ca", &user, &host, version, is_be)
640}
641
642/// Search for a PV via a TCP connection to a PVA name server.
643///
644/// Connects to the name server, performs the PVA handshake, sends a search
645/// request over TCP, and returns the server address from the search response.
646pub async fn search_pv_tcp(
647    pv_name: &str,
648    name_server: SocketAddr,
649    timeout_dur: Duration,
650    debug_enabled: bool,
651) -> Result<SocketAddr, PvGetError> {
652    let deadline = tokio::time::Instant::now() + timeout_dur;
653
654    let mut stream =
655        tokio::time::timeout(timeout_dur, tokio::net::TcpStream::connect(name_server))
656            .await
657            .map_err(|_| PvGetError::Timeout("name server connect"))??;
658
659    let mut version = 2u8;
660    let mut is_be = false;
661
662    // Read SET_BYTE_ORDER + ConnectionValidation from name server.
663    for _ in 0..2 {
664        let now = tokio::time::Instant::now();
665        if now >= deadline {
666            return Err(PvGetError::Timeout("name server handshake"));
667        }
668        let remaining = deadline - now;
669        if let Ok(bytes) = read_packet(&mut stream, remaining).await {
670            let mut pkt = PvaPacket::new(&bytes);
671            if let Some(cmd) = pkt.decode_payload() {
672                match cmd {
673                    PvaPacketCommand::Control(payload) => {
674                        if payload.command == 2 {
675                            is_be = pkt.header.flags.is_msb;
676                        }
677                    }
678                    PvaPacketCommand::ConnectionValidation(_) => {
679                        version = pkt.header.version;
680                        is_be = pkt.header.flags.is_msb;
681                    }
682                    _ => {}
683                }
684            }
685        }
686    }
687
688    let validation = encode_search_validation(version, is_be);
689    stream.write_all(&validation).await?;
690
691    // Wait for ConnectionValidated.
692    loop {
693        let now = tokio::time::Instant::now();
694        if now >= deadline {
695            return Err(PvGetError::Timeout("name server validated"));
696        }
697        let remaining = deadline - now;
698        let bytes = read_packet(&mut stream, remaining).await?;
699        let mut pkt = PvaPacket::new(&bytes);
700        if let Some(cmd) = pkt.decode_payload() {
701            if matches!(cmd, PvaPacketCommand::ConnectionValidated(_)) {
702                break;
703            }
704        }
705    }
706
707    // Send search request over TCP.
708    let now_ts = std::time::SystemTime::now()
709        .duration_since(std::time::UNIX_EPOCH)
710        .unwrap_or_default();
711    let seq = (now_ts.as_nanos() as u32).wrapping_add(std::process::id());
712    let cid = seq ^ 0x9E37_79B9;
713    let requests = [(cid, pv_name)];
714    let msg = encode_search_request(seq, 0x80, 0, [0u8; 16], &requests, version, is_be);
715    stream.write_all(&msg).await?;
716
717    if debug_enabled {
718        debug!(
719            "pva tcp search sent to name_server={} pv={}",
720            name_server, pv_name
721        );
722    }
723
724    // Read search response.
725    loop {
726        let now = tokio::time::Instant::now();
727        if now >= deadline {
728            return Err(PvGetError::Timeout("name server search response"));
729        }
730        let remaining = deadline - now;
731        let bytes = read_packet(&mut stream, remaining).await?;
732        let mut pkt = PvaPacket::new(&bytes);
733        if let Some(cmd) = pkt.decode_payload() {
734            if let PvaPacketCommand::SearchResponse(payload) = cmd {
735                if !payload.found {
736                    continue;
737                }
738                if !payload.cids.is_empty() && !payload.cids.contains(&cid) {
739                    continue;
740                }
741                let addr =
742                    decode_search_response_addr(payload.addr, payload.port, name_server);
743                if debug_enabled {
744                    debug!(
745                        "pva tcp search response from name_server={}: {}",
746                        name_server, addr
747                    );
748                }
749                return Ok(addr);
750            }
751        }
752    }
753}
754
755/// Resolve the PVA server for a PV using name servers (TCP) and/or UDP search.
756///
757/// - If `opts.server_addr` is set, returns it directly.
758/// - Tries each name server from `opts.name_servers` and `EPICS_PVA_NAME_SERVERS`
759///   via TCP search.
760/// - Falls back to UDP search using `build_search_targets()`.
761pub async fn resolve_pv_server(opts: &PvGetOptions) -> Result<SocketAddr, PvGetError> {
762    if let Some(addr) = opts.server_addr {
763        return Ok(addr);
764    }
765
766    let mut name_servers = opts.name_servers.clone();
767    if let Ok(env) = std::env::var("EPICS_PVA_NAME_SERVERS") {
768        name_servers.extend(parse_name_servers(&env));
769    }
770
771    let no_broadcast = opts.no_broadcast;
772
773    // Fail fast when no search strategy is available.
774    if no_broadcast && name_servers.is_empty() {
775        return Err(PvGetError::Search(
776            "no search strategy: specify --name-server or --server when using --no-broadcast",
777        ));
778    }
779
780    // Launch all search strategies concurrently — TCP name servers + UDP broadcast.
781    // Return the first successful result.
782    let targets = build_search_targets(opts.search_addr, opts.bind_addr);
783
784    let pv = opts.pv_name.clone();
785    let timeout_dur = opts.timeout;
786    let debug_enabled = opts.debug;
787    let udp_port = opts.udp_port;
788
789    let mut set = tokio::task::JoinSet::new();
790
791    for ns in name_servers {
792        let pv = pv.clone();
793        set.spawn(async move {
794            let addr = search_pv_tcp(&pv, ns, timeout_dur, debug_enabled).await?;
795            Ok::<SocketAddr, PvGetError>(addr)
796        });
797    }
798
799    if !no_broadcast {
800        let pv = pv.clone();
801        let targets = targets.clone();
802        set.spawn(async move {
803            let addr = search_pv(&pv, udp_port, timeout_dur, &targets, debug_enabled).await?;
804            Ok(addr)
805        });
806    }
807
808    let mut last_err = None;
809    while let Some(result) = set.join_next().await {
810        match result {
811            Ok(Ok(addr)) => {
812                set.abort_all();
813                return Ok(addr);
814            }
815            Ok(Err(e)) => {
816                if debug_enabled {
817                    debug!("pva search strategy failed: {}", e);
818                }
819                last_err = Some(e);
820            }
821            Err(join_err) => {
822                if debug_enabled {
823                    debug!("pva search task panicked: {}", join_err);
824                }
825            }
826        }
827    }
828
829    Err(last_err.unwrap_or(PvGetError::Timeout("search response")))
830}
831
832pub async fn discover_servers(
833    udp_port: u16,
834    timeout_dur: Duration,
835    targets: &[SearchTarget],
836    debug_enabled: bool,
837) -> Result<Vec<DiscoveredServer>, PvGetError> {
838    if targets.is_empty() {
839        return Err(PvGetError::Search("no search targets"));
840    }
841
842    let now = std::time::SystemTime::now()
843        .duration_since(std::time::UNIX_EPOCH)
844        .unwrap_or_default();
845    let seq = (now.as_nanos() as u32).wrapping_add(std::process::id());
846
847    let mut found: Vec<DiscoveredServer> = Vec::new();
848    let mut last_io_error: Option<std::io::Error> = None;
849    let deadline = tokio::time::Instant::now() + timeout_dur;
850
851    // Group targets by bind address so we can share a socket per bind.
852    let mut bind_groups: Vec<(IpAddr, Vec<IpAddr>)> = Vec::new();
853    for t in targets {
854        if let Some(group) = bind_groups.iter_mut().find(|(b, _)| *b == t.bind) {
855            group.1.push(t.target);
856        } else {
857            bind_groups.push((t.bind, vec![t.target]));
858        }
859    }
860
861    // Open sockets and send to all targets first, then collect responses.
862    // Store (socket, message, destinations) for retransmission.
863    let mut socket_info: Vec<(Arc<UdpSocket>, Vec<u8>, Vec<SocketAddr>)> = Vec::new();
864
865    for (bind_ip, group_targets) in &bind_groups {
866        let bind_addr = SocketAddr::new(*bind_ip, udp_port);
867        let (std_sock, actual_bind_addr) = match bind_udp_reuse(bind_addr) {
868            Ok(sock) => (sock, bind_addr),
869            Err(err) if err.kind() == std::io::ErrorKind::AddrInUse => {
870                let fallback = SocketAddr::new(*bind_ip, 0);
871                match bind_udp_reuse(fallback) {
872                    Ok(sock) => {
873                        let actual = sock.local_addr().unwrap_or(fallback);
874                        if debug_enabled {
875                            debug!(
876                                "pva discover bind={} failed (in use), fallback bind={}",
877                                bind_addr, actual
878                            );
879                        }
880                        (sock, actual)
881                    }
882                    Err(fallback_err) => {
883                        if debug_enabled {
884                            debug!(
885                                "pva discover skipping bind={} step=bind-fallback kind={:?} err={}",
886                                bind_addr,
887                                fallback_err.kind(),
888                                fallback_err
889                            );
890                        }
891                        last_io_error = Some(fallback_err);
892                        continue;
893                    }
894                }
895            }
896            Err(err) => {
897                if debug_enabled {
898                    debug!(
899                        "pva discover skipping bind={} step=bind kind={:?} err={}",
900                        bind_addr,
901                        err.kind(),
902                        err
903                    );
904                }
905                last_io_error = Some(err);
906                continue;
907            }
908        };
909        if let Err(err) = std_sock.set_broadcast(true) {
910            if debug_enabled {
911                debug!(
912                    "pva discover skipping bind={} step=set_broadcast kind={:?} err={}",
913                    bind_addr,
914                    err.kind(),
915                    err
916                );
917            }
918            last_io_error = Some(err);
919            continue;
920        }
921
922        join_multicast_any(&std_sock, *bind_ip);
923
924        let reply_addr = ip_to_bytes(*bind_ip);
925        let reply_port = match std_sock.local_addr() {
926            Ok(addr) => addr.port(),
927            Err(err) => {
928                if debug_enabled {
929                    debug!(
930                        "pva discover skipping bind={} step=local_addr kind={:?} err={}",
931                        bind_addr,
932                        err.kind(),
933                        err
934                    );
935                }
936                last_io_error = Some(err);
937                continue;
938            }
939        };
940        let msg = encode_search_request(seq, 0x81, reply_port, reply_addr, &[], 2, false);
941
942        let socket = match UdpSocket::from_std(std_sock) {
943            Ok(socket) => socket,
944            Err(err) => {
945                if debug_enabled {
946                    debug!(
947                        "pva discover skipping bind={} step=from_std kind={:?} err={}",
948                        bind_addr,
949                        err.kind(),
950                        err
951                    );
952                }
953                last_io_error = Some(err);
954                continue;
955            }
956        };
957
958        let dests: Vec<SocketAddr> = group_targets
959            .iter()
960            .map(|ip| SocketAddr::new(*ip, udp_port))
961            .collect();
962
963        // Send to every target in this bind group immediately.
964        for dest in &dests {
965            if debug_enabled {
966                debug!(
967                    "pva discover bind={} target={} server_port={} reply_port={} seq={}",
968                    actual_bind_addr, dest.ip(), udp_port, reply_port, seq
969                );
970            }
971            if let Err(err) = socket.send_to(&msg, dest).await {
972                if debug_enabled {
973                    debug!(
974                        "pva discover send_to target={} kind={:?} err={}",
975                        dest,
976                        err.kind(),
977                        err
978                    );
979                }
980                last_io_error = Some(err);
981            }
982        }
983
984        socket_info.push((Arc::new(socket), msg, dests));
985    }
986
987    if socket_info.is_empty() {
988        if let Some(err) = last_io_error {
989            return Err(PvGetError::Io(err));
990        }
991        return Err(PvGetError::Search("no search targets"));
992    }
993
994    // Spawn a receiver task per socket that forwards packets into a shared channel.
995    let (tx, mut rx) = tokio::sync::mpsc::channel::<(Vec<u8>, SocketAddr)>(64);
996    for (sock, _, _) in &socket_info {
997        let sock = Arc::clone(sock);
998        let tx = tx.clone();
999        tokio::spawn(async move {
1000            loop {
1001                let mut buf = vec![0u8; 2048];
1002                match sock.recv_from(&mut buf).await {
1003                    Ok((len, src)) => {
1004                        buf.truncate(len);
1005                        if tx.send((buf, src)).await.is_err() {
1006                            break;
1007                        }
1008                    }
1009                    Err(_) => break,
1010                }
1011            }
1012        });
1013    }
1014    drop(tx); // Only spawned tasks hold senders; channel closes when they exit.
1015
1016    // Retransmit schedule: exponential backoff from start.
1017    let retransmit_offsets = [100u64, 500, 1000, 2000];
1018    let start = tokio::time::Instant::now();
1019    let mut next_retransmit = 0usize;
1020
1021    loop {
1022        // Compute the next wake-up: either the next retransmit or the deadline.
1023        let next_retransmit_at = if next_retransmit < retransmit_offsets.len() {
1024            start + Duration::from_millis(retransmit_offsets[next_retransmit])
1025        } else {
1026            deadline
1027        };
1028        let wake_at = next_retransmit_at.min(deadline);
1029
1030        tokio::select! {
1031            recv = rx.recv() => {
1032                let Some((buf, src)) = recv else { break };
1033                let mut pkt = PvaPacket::new(&buf);
1034                let Some(cmd) = pkt.decode_payload() else {
1035                    continue;
1036                };
1037                if let PvaPacketCommand::SearchResponse(payload) = cmd {
1038                    if payload.seq != seq {
1039                        continue;
1040                    }
1041                    if !payload.protocol.is_empty() && !payload.protocol.eq_ignore_ascii_case("tcp") {
1042                        continue;
1043                    }
1044                    let tcp_addr = decode_search_response_addr(payload.addr, payload.port, src);
1045                    found.push(DiscoveredServer {
1046                        guid: payload.guid,
1047                        tcp_addr,
1048                    });
1049                }
1050            }
1051            _ = tokio::time::sleep_until(wake_at) => {
1052                if tokio::time::Instant::now() >= deadline {
1053                    break;
1054                }
1055                // Retransmit to all targets on all sockets.
1056                if next_retransmit < retransmit_offsets.len() {
1057                    if debug_enabled {
1058                        debug!("pva discover retransmit round {}", next_retransmit + 1);
1059                    }
1060                    for (sock, msg, dests) in &socket_info {
1061                        for dest in dests {
1062                            let _ = sock.send_to(msg, dest).await;
1063                        }
1064                    }
1065                    next_retransmit += 1;
1066                }
1067            }
1068        }
1069    }
1070
1071    Ok(normalize_discovered_servers(found))
1072}
1073
#[cfg(test)]
mod tests {
    // Unit tests for search request encoding, discovery helpers and address-list parsing.
    use super::*;
    use spvirit_codec::epics_decode::{PvaPacket, PvaPacketCommand};

    #[test]
    fn encode_decode_search_request_roundtrip() {
        let seq = 1234;
        let cid = 42;
        let port = 5076;
        let pv_name = "TEST:PV";
        let reply_addr = ip_to_bytes(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 20)));
        let requests = [(cid, pv_name)];
        let msg = encode_search_request(seq, 0x81, port, reply_addr, &requests, 2, false);
        let mut pkt = PvaPacket::new(&msg);
        let cmd = pkt.decode_payload().expect("decoded");
        match cmd {
            PvaPacketCommand::Search(payload) => {
                assert_eq!(payload.seq, seq);
                assert_eq!(payload.mask, 0x81);
                assert_eq!(payload.addr, reply_addr);
                assert_eq!(payload.port, port);
                assert_eq!(payload.protocols, vec!["tcp".to_string()]);
                assert_eq!(payload.pv_requests.len(), 1);
                assert_eq!(payload.pv_requests[0].0, cid);
                assert_eq!(payload.pv_requests[0].1, pv_name.to_string());
            }
            other => panic!("unexpected decode: {:?}", other),
        }
    }

    #[test]
    fn encode_decode_server_discovery_request_roundtrip() {
        let seq = 4321;
        let port = 5076;
        let reply_addr = ip_to_bytes(IpAddr::V4(Ipv4Addr::new(10, 20, 30, 40)));
        let msg = encode_search_request(seq, 0x81, port, reply_addr, &[], 2, false);
        let mut pkt = PvaPacket::new(&msg);
        let cmd = pkt.decode_payload().expect("decoded");
        match cmd {
            PvaPacketCommand::Search(payload) => {
                assert_eq!(payload.seq, seq);
                assert_eq!(payload.pv_requests.len(), 0);
                assert_eq!(payload.protocols, vec!["tcp".to_string()]);
            }
            other => panic!("unexpected decode: {:?}", other),
        }
    }

    #[test]
    fn normalize_discovered_servers_deduplicates_by_guid_and_addr() {
        let guid = [1u8; 12];
        let s1 = DiscoveredServer {
            guid,
            tcp_addr: "127.0.0.1:5075".parse().unwrap(),
        };
        let s2 = DiscoveredServer {
            guid,
            tcp_addr: "127.0.0.1:5075".parse().unwrap(),
        };
        let s3 = DiscoveredServer {
            guid: [2u8; 12],
            tcp_addr: "127.0.0.1:5075".parse().unwrap(),
        };
        let normalized = normalize_discovered_servers(vec![s1, s2, s3]);
        assert_eq!(normalized.len(), 2);
    }

    #[test]
    fn parse_addr_list_accepts_ip_and_ip_port() {
        // Compare the whole vector: the port is stripped, input order is kept, and no extra
        // entries are produced.
        let items = parse_addr_list("192.168.1.10 10.0.0.1:5076");
        assert_eq!(
            items,
            vec![
                IpAddr::V4(Ipv4Addr::new(192, 168, 1, 10)),
                IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            ]
        );
    }

    #[test]
    fn parse_addr_list_handles_mixed_separators_and_empty_tokens() {
        let items = parse_addr_list(" 10.0.0.1,,10.0.0.2\t 10.0.0.3 ");
        assert_eq!(
            items,
            vec![
                IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
                IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
                IpAddr::V4(Ipv4Addr::new(10, 0, 0, 3)),
            ]
        );
    }

    #[test]
    fn parse_addr_list_empty_or_separator_only_input() {
        assert!(parse_addr_list("").is_empty());
        assert!(parse_addr_list(" ,\t, ").is_empty());
    }

    #[test]
    fn parse_addr_list_accepts_bracketed_ipv6_with_port() {
        assert_eq!(
            parse_addr_list("[::1]:5076"),
            vec![IpAddr::V6(Ipv6Addr::LOCALHOST)]
        );
    }

    #[test]
    fn discovery_target_falls_back_to_limited_broadcast_for_invalid_netmask() {
        let ip = Ipv4Addr::new(130, 246, 90, 92);
        assert_eq!(
            discovery_target_for(ip, Ipv4Addr::new(255, 255, 255, 255)),
            Ipv4Addr::new(255, 255, 255, 255)
        );
        assert_eq!(
            discovery_target_for(ip, Ipv4Addr::new(0, 0, 0, 0)),
            Ipv4Addr::new(255, 255, 255, 255)
        );
    }

    #[test]
    fn discovery_target_uses_directed_broadcast_for_normal_subnet() {
        let ip = Ipv4Addr::new(192, 168, 56, 1);
        let netmask = Ipv4Addr::new(255, 255, 255, 0);
        assert_eq!(
            discovery_target_for(ip, netmask),
            Ipv4Addr::new(192, 168, 56, 255)
        );
    }

    #[test]
    fn parse_name_servers_ip_with_port() {
        let addrs = parse_name_servers("192.168.1.10:5075");
        assert_eq!(addrs, vec!["192.168.1.10:5075".parse::<SocketAddr>().unwrap()]);
    }

    #[test]
    fn parse_name_servers_ip_without_port_defaults_to_5075() {
        let addrs = parse_name_servers("10.0.0.1");
        assert_eq!(addrs, vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 5075)]);
    }

    #[test]
    fn parse_name_servers_multiple_comma_separated() {
        let addrs = parse_name_servers("10.0.0.1:5075,10.0.0.2:9876");
        assert_eq!(addrs.len(), 2);
        assert_eq!(addrs[0], "10.0.0.1:5075".parse::<SocketAddr>().unwrap());
        assert_eq!(addrs[1], "10.0.0.2:9876".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn parse_name_servers_multiple_space_separated() {
        let addrs = parse_name_servers("10.0.0.1 10.0.0.2:5075");
        assert_eq!(addrs.len(), 2);
        assert_eq!(addrs[0], SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 5075));
        assert_eq!(addrs[1], "10.0.0.2:5075".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn parse_name_servers_empty_string() {
        let addrs = parse_name_servers("");
        assert!(addrs.is_empty());
    }

    #[test]
    fn parse_name_servers_whitespace_only() {
        let addrs = parse_name_servers("  \t  ");
        assert!(addrs.is_empty());
    }

    #[test]
    fn parse_name_servers_mixed_separators() {
        let addrs = parse_name_servers("10.0.0.1:5075, 10.0.0.2  ,  10.0.0.3:9999");
        assert_eq!(addrs.len(), 3);
        assert_eq!(addrs[0], "10.0.0.1:5075".parse::<SocketAddr>().unwrap());
        assert_eq!(addrs[1], SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)), 5075));
        assert_eq!(addrs[2], "10.0.0.3:9999".parse::<SocketAddr>().unwrap());
    }

    #[test]
    fn parse_name_servers_ipv6_with_port() {
        let addrs = parse_name_servers("[::1]:5075");
        assert_eq!(addrs, vec![SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 5075)]);
    }

    #[test]
    fn parse_name_servers_ipv6_without_port() {
        let addrs = parse_name_servers("::1");
        assert_eq!(addrs, vec![SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 5075)]);
    }

    #[test]
    fn decode_search_response_addr_falls_back_to_udp_source_when_unspecified() {
        let src: SocketAddr = "192.168.1.20:5076".parse().unwrap();
        let decoded = decode_search_response_addr([0u8; 16], 5075, src);
        assert_eq!(decoded, "192.168.1.20:5075".parse().unwrap());
    }
}