1use std::collections::HashSet;
2use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
3use std::sync::Arc;
4use std::time::Duration;
5
6use dns_lookup::lookup_host;
7use get_if_addrs::{get_if_addrs, IfAddr};
8use socket2::{Domain, Protocol, Socket, Type};
9use tokio::io::AsyncWriteExt;
10use tokio::net::UdpSocket;
11use tracing::debug;
12
13use crate::auth::{default_authnz_host, default_authnz_user};
14use crate::transport::read_packet;
15use crate::types::{PvGetError, PvGetOptions};
16use spvirit_codec::epics_decode::{PvaPacket, PvaPacketCommand};
17use spvirit_codec::spvirit_encode::{
18 encode_client_connection_validation, encode_search_request, ip_to_bytes,
19 socket_addr_from_pva_bytes,
20};
21
/// One UDP search destination together with the local address the search is sent from.
///
/// The search functions group targets by `bind` and open one socket per distinct bind
/// address, sending to every `target` of that group.
#[derive(Clone, Copy, Debug)]
pub struct SearchTarget {
    /// Address the search request is sent to (combined with the UDP search port).
    pub target: IpAddr,
    /// Local address the sending socket is bound to; an unspecified address binds to "any".
    pub bind: IpAddr,
}
27
/// A server that answered a discovery (empty) search request.
///
/// Equality and hashing cover both fields, so one GUID seen at two addresses, or two GUIDs
/// at one address, are distinct entries.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DiscoveredServer {
    /// Server GUID copied from the search response.
    pub guid: [u8; 12],
    /// TCP address advertised in the response, or the response's source IP (with the
    /// advertised port) when the advertised address is unspecified.
    pub tcp_addr: SocketAddr,
}
33
34pub fn parse_addr_list(env: &str) -> Vec<IpAddr> {
35 env.split(|c| c == ',' || c == ' ' || c == '\t')
36 .filter(|s| !s.trim().is_empty())
37 .filter_map(|s| parse_search_target_ip(s.trim()))
38 .collect()
39}
40
41fn parse_search_target_ip(token: &str) -> Option<IpAddr> {
42 if token.is_empty() {
43 return None;
44 }
45
46 if let Ok(ip) = token.parse::<IpAddr>() {
47 return Some(ip);
48 }
49 if let Ok(sock) = token.parse::<SocketAddr>() {
50 return Some(sock.ip());
51 }
52
53 if let Some((host, port_str)) = token.rsplit_once(':') {
56 if !host.is_empty()
57 && !port_str.is_empty()
58 && port_str.chars().all(|c| c.is_ascii_digit())
59 && !host.contains(']')
60 {
61 if let Ok(ip) = host.parse::<IpAddr>() {
62 return Some(ip);
63 }
64 if let Ok(addrs) = lookup_host(host) {
65 let addrs: Vec<IpAddr> = addrs.collect();
67 if let Some(ip) = addrs.iter().find(|ip| ip.is_ipv4()).copied()
68 .or_else(|| addrs.into_iter().next())
69 {
70 return Some(ip);
71 }
72 }
73 }
74 }
75
76 if let Ok(addrs) = lookup_host(token) {
77 let addrs: Vec<IpAddr> = addrs.collect();
79 if let Some(ip) = addrs.iter().find(|ip| ip.is_ipv4()).copied()
80 .or_else(|| addrs.into_iter().next())
81 {
82 return Some(ip);
83 }
84 }
85
86 None
87}
88
/// Returns the unspecified ("any") address of the same family as `ip`:
/// `0.0.0.0` for IPv4, `::` for IPv6.
fn unspecified_for(ip: IpAddr) -> IpAddr {
    if ip.is_ipv4() {
        Ipv4Addr::UNSPECIFIED.into()
    } else {
        Ipv6Addr::UNSPECIFIED.into()
    }
}
96
97pub fn build_search_targets(
98 search_addr: Option<IpAddr>,
99 bind_addr: Option<IpAddr>,
100) -> Vec<SearchTarget> {
101 if let Some(ip) = search_addr {
103 return vec![SearchTarget {
104 target: ip,
105 bind: bind_addr.unwrap_or_else(|| unspecified_for(ip)),
106 }];
107 }
108
109 let mut targets = Vec::new();
110 let mut seen = HashSet::new();
111
112 if let Ok(env) = std::env::var("EPICS_PVA_ADDR_LIST") {
114 for ip in parse_addr_list(&env) {
115 if seen.insert(ip) {
116 targets.push(SearchTarget {
117 target: ip,
118 bind: bind_addr.unwrap_or_else(|| unspecified_for(ip)),
119 });
120 }
121 }
122 }
123
124 if is_auto_addr_list_enabled() {
127 for t in build_auto_broadcast_targets() {
128 if seen.insert(t.target) {
129 targets.push(SearchTarget {
130 target: t.target,
131 bind: bind_addr.unwrap_or(t.bind),
132 });
133 }
134 }
135 }
136
137 targets
138}
139
/// Reports whether the automatic (per-interface) search address list is enabled.
///
/// Reads `EPICS_PVA_AUTO_ADDR_LIST`. When the variable is unset or not valid Unicode, the
/// answer is `true`. Otherwise the trimmed value, compared case-insensitively, must be one of
/// `YES`, `Y`, `1` or `TRUE`; any other value, including an empty one, disables it.
pub fn is_auto_addr_list_enabled() -> bool {
    let Ok(raw) = std::env::var("EPICS_PVA_AUTO_ADDR_LIST") else {
        return true;
    };
    matches!(
        raw.trim().to_ascii_uppercase().as_str(),
        "YES" | "Y" | "1" | "TRUE"
    )
}
149
/// Returns `true` for IPv4 link-local addresses (`169.254.0.0/16`).
///
/// Delegates to the standard library's `Ipv4Addr::is_link_local`, which tests exactly this
/// range, instead of hand-rolling the octet comparison.
fn ipv4_is_link_local(ip: Ipv4Addr) -> bool {
    ip.is_link_local()
}
154
155fn choose_default_bind_v4() -> Option<Ipv4Addr> {
156 let ifaces = get_if_addrs().ok()?;
157 for iface in ifaces {
158 if let IfAddr::V4(v4) = iface.addr {
159 let ip = v4.ip;
160 if ip.is_loopback() || ipv4_is_link_local(ip) {
161 continue;
162 }
163 return Some(ip);
164 }
165 }
166 None
167}
168
169fn choose_default_bind_v6() -> Option<Ipv6Addr> {
170 let ifaces = get_if_addrs().ok()?;
171 for iface in ifaces {
172 if let IfAddr::V6(v6) = iface.addr {
173 let ip = v6.ip;
174 if ip.is_loopback() {
175 continue;
176 }
177 let segs = ip.segments();
179 if segs[0] & 0xffc0 == 0xfe80 {
180 continue;
181 }
182 return Some(ip);
183 }
184 }
185 None
186}
187
/// Computes the directed broadcast address of `ip`'s subnet: every host bit (those cleared
/// in `netmask`) is set to 1.
fn broadcast_for(ip: Ipv4Addr, netmask: Ipv4Addr) -> Ipv4Addr {
    let host_bits = !u32::from(netmask);
    Ipv4Addr::from(u32::from(ip) | host_bits)
}
193
/// Chooses the IPv4 destination used to reach servers on `ip`'s subnet.
///
/// Normally this is the subnet's directed broadcast (`ip | !netmask`). The limited broadcast
/// `255.255.255.255` is used instead when the mask is all-ones or all-zeros, or when the
/// directed broadcast equals `ip` itself (no usable host bits).
fn discovery_target_for(ip: Ipv4Addr, netmask: Ipv4Addr) -> Ipv4Addr {
    const LIMITED: Ipv4Addr = Ipv4Addr::BROADCAST;

    if netmask.is_unspecified() || netmask == LIMITED {
        return LIMITED;
    }
    let directed = Ipv4Addr::from(u32::from(ip) | !u32::from(netmask));
    match directed == ip {
        true => LIMITED,
        false => directed,
    }
}
206
207pub fn build_auto_broadcast_targets() -> Vec<SearchTarget> {
208 let mut targets = Vec::new();
209 let mut fallback_targets = Vec::new();
210 let mut fallback_seen = HashSet::new();
211 let mut added_v4_multicast = false;
212 let mut added_v6_multicast = false;
213 let ifaces = match get_if_addrs() {
214 Ok(v) => v,
215 Err(_) => return targets,
216 };
217 for iface in &ifaces {
218 if let IfAddr::V4(v4) = &iface.addr {
219 let ip = v4.ip;
220 if ip.is_loopback() || ipv4_is_link_local(ip) {
221 continue;
222 }
223 let bcast = discovery_target_for(ip, v4.netmask);
224 targets.push(SearchTarget {
225 target: IpAddr::V4(bcast),
226 bind: IpAddr::V4(ip),
227 });
228 targets.push(SearchTarget {
231 target: IpAddr::V4(PVA_MULTICAST_V4),
232 bind: IpAddr::V4(ip),
233 });
234 if fallback_seen.insert(IpAddr::V4(bcast)) {
235 fallback_targets.push(SearchTarget {
236 target: IpAddr::V4(bcast),
237 bind: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
238 });
239 }
240 if !added_v4_multicast {
241 added_v4_multicast = true;
242 fallback_targets.push(SearchTarget {
243 target: IpAddr::V4(PVA_MULTICAST_V4),
244 bind: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
245 });
246 }
247 }
248 }
249 for iface in &ifaces {
251 if let IfAddr::V6(v6) = &iface.addr {
252 let ip = v6.ip;
253 if ip.is_loopback() {
254 continue;
255 }
256 let segs = ip.segments();
257 if segs[0] & 0xffc0 == 0xfe80 {
258 continue; }
260 let multicast_target = IpAddr::V6(PVA_MULTICAST_V6);
261 targets.push(SearchTarget {
262 target: multicast_target,
263 bind: IpAddr::V6(ip),
264 });
265 if !added_v6_multicast {
266 added_v6_multicast = true;
267 fallback_targets.push(SearchTarget {
268 target: multicast_target,
269 bind: IpAddr::V6(Ipv6Addr::UNSPECIFIED),
270 });
271 }
272 }
273 }
274 targets.extend(fallback_targets);
275 targets
276}
277
/// IPv4 multicast group used for automatic search targets (one per qualifying interface plus
/// a wildcard-bound fallback) and joined by search sockets on their bind interface.
const PVA_MULTICAST_V4: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 128);

/// IPv6 multicast group `ff02::42:1` (the `ff02` prefix is link-local scope), used for
/// automatic IPv6 search targets and joined by IPv6 search sockets on interface index 0.
const PVA_MULTICAST_V6: Ipv6Addr = Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0x42, 1);
283
284fn join_multicast_any(socket: &std::net::UdpSocket, bind: IpAddr) {
286 match bind {
287 IpAddr::V4(iface) => {
288 let _ = socket.join_multicast_v4(&PVA_MULTICAST_V4, &iface);
289 }
290 IpAddr::V6(_) => {
291 let _ = socket.join_multicast_v6(&PVA_MULTICAST_V6, 0);
293 }
294 }
295}
296
297fn decode_search_response_addr(addr: [u8; 16], port: u16, src: SocketAddr) -> SocketAddr {
298 socket_addr_from_pva_bytes(addr, port)
299 .filter(|a| !a.ip().is_unspecified())
300 .unwrap_or_else(|| SocketAddr::new(src.ip(), port))
301}
302
303fn normalize_discovered_servers(items: Vec<DiscoveredServer>) -> Vec<DiscoveredServer> {
304 let mut seen = HashSet::new();
305 let mut out = Vec::new();
306 for item in items {
307 if seen.insert((item.guid, item.tcp_addr)) {
308 out.push(item);
309 }
310 }
311 out.sort_by(|a, b| a.tcp_addr.to_string().cmp(&b.tcp_addr.to_string()));
312 out
313}
314
315fn bind_udp_reuse(addr: SocketAddr) -> std::io::Result<std::net::UdpSocket> {
322 let domain = if addr.is_ipv4() {
323 Domain::IPV4
324 } else {
325 Domain::IPV6
326 };
327 let sock = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
328 #[cfg(unix)]
329 sock.set_reuse_address(true)?;
330 sock.set_nonblocking(true)?;
331 sock.bind(&addr.into())?;
332 Ok(sock.into())
333}
334
335pub async fn search_pv(
336 pv_name: &str,
337 udp_port: u16,
338 timeout_dur: Duration,
339 targets: &[SearchTarget],
340 debug_enabled: bool,
341) -> Result<SocketAddr, PvGetError> {
342 if targets.is_empty() {
343 return Err(PvGetError::Search("no search targets"));
344 }
345
346 let now = std::time::SystemTime::now()
347 .duration_since(std::time::UNIX_EPOCH)
348 .unwrap_or_default();
349 let seq = (now.as_nanos() as u32).wrapping_add(std::process::id());
350 let cid = seq ^ 0x9E37_79B9;
351
352 let mut last_io_error: Option<std::io::Error> = None;
353 let deadline = tokio::time::Instant::now() + timeout_dur;
354
355 let mut bind_groups: Vec<(IpAddr, Vec<IpAddr>)> = Vec::new();
357 for t in targets {
358 if let Some(group) = bind_groups.iter_mut().find(|(b, _)| *b == t.bind) {
359 group.1.push(t.target);
360 } else {
361 bind_groups.push((t.bind, vec![t.target]));
362 }
363 }
364
365 let mut socket_info: Vec<(Arc<UdpSocket>, Vec<u8>, Vec<SocketAddr>)> = Vec::new();
368
369 for (bind_ip, group_targets) in &bind_groups {
370 let bind_addr = SocketAddr::new(*bind_ip, udp_port);
371 let (std_sock, actual_bind_addr) = match bind_udp_reuse(bind_addr) {
372 Ok(sock) => (sock, bind_addr),
373 Err(err) if err.kind() == std::io::ErrorKind::AddrInUse => {
374 let fallback = SocketAddr::new(*bind_ip, 0);
375 match bind_udp_reuse(fallback) {
376 Ok(sock) => {
377 let actual = sock.local_addr().unwrap_or(fallback);
378 if debug_enabled {
379 debug!(
380 "pva search bind={} failed (in use), fallback bind={}",
381 bind_addr, actual
382 );
383 }
384 (sock, actual)
385 }
386 Err(fallback_err) => {
387 if debug_enabled {
388 debug!(
389 "pva search skipping bind={} step=bind-fallback kind={:?} err={}",
390 bind_addr,
391 fallback_err.kind(),
392 fallback_err
393 );
394 }
395 last_io_error = Some(fallback_err);
396 continue;
397 }
398 }
399 }
400 Err(err) => {
401 if debug_enabled {
402 debug!(
403 "pva search skipping bind={} step=bind kind={:?} err={}",
404 bind_addr,
405 err.kind(),
406 err
407 );
408 }
409 last_io_error = Some(err);
410 continue;
411 }
412 };
413 if let Err(err) = std_sock.set_broadcast(true) {
414 if debug_enabled {
415 debug!(
416 "pva search skipping bind={} step=set_broadcast kind={:?} err={}",
417 bind_addr,
418 err.kind(),
419 err
420 );
421 }
422 last_io_error = Some(err);
423 continue;
424 }
425
426 join_multicast_any(&std_sock, *bind_ip);
427
428 let reply_addr = ip_to_bytes(*bind_ip);
429 let reply_port = match std_sock.local_addr() {
430 Ok(addr) => addr.port(),
431 Err(err) => {
432 if debug_enabled {
433 debug!(
434 "pva search skipping bind={} step=local_addr kind={:?} err={}",
435 bind_addr,
436 err.kind(),
437 err
438 );
439 }
440 last_io_error = Some(err);
441 continue;
442 }
443 };
444 let requests = [(cid, pv_name)];
445 let msg = encode_search_request(seq, 0x81, reply_port, reply_addr, &requests, 2, false);
446
447 let socket = match UdpSocket::from_std(std_sock) {
448 Ok(socket) => socket,
449 Err(err) => {
450 if debug_enabled {
451 debug!(
452 "pva search skipping bind={} step=from_std kind={:?} err={}",
453 bind_addr,
454 err.kind(),
455 err
456 );
457 }
458 last_io_error = Some(err);
459 continue;
460 }
461 };
462
463 let dests: Vec<SocketAddr> = group_targets
464 .iter()
465 .map(|ip| SocketAddr::new(*ip, udp_port))
466 .collect();
467
468 for dest in &dests {
470 if debug_enabled {
471 debug!(
472 "pva search bind={} target={} server_port={} reply_port={}",
473 actual_bind_addr, dest.ip(), udp_port, reply_port
474 );
475 debug!("pva search seq={} cid={}", seq, cid);
476 debug!("pva search send {} bytes to {}", msg.len(), dest);
477 }
478 if let Err(err) = socket.send_to(&msg, dest).await {
479 if debug_enabled {
480 debug!(
481 "pva search send_to target={} kind={:?} err={}",
482 dest,
483 err.kind(),
484 err
485 );
486 }
487 last_io_error = Some(err);
488 }
489 }
490
491 socket_info.push((Arc::new(socket), msg, dests));
492 }
493
494 if socket_info.is_empty() {
495 if let Some(err) = last_io_error {
496 return Err(PvGetError::Io(err));
497 }
498 return Err(PvGetError::Timeout("search response"));
499 }
500
501 let (tx, mut rx) = tokio::sync::mpsc::channel::<(Vec<u8>, SocketAddr)>(64);
503 for (sock, _, _) in &socket_info {
504 let sock = Arc::clone(sock);
505 let tx = tx.clone();
506 tokio::spawn(async move {
507 loop {
508 let mut buf = vec![0u8; 2048];
509 match sock.recv_from(&mut buf).await {
510 Ok((len, src)) => {
511 buf.truncate(len);
512 if tx.send((buf, src)).await.is_err() {
513 break;
514 }
515 }
516 Err(_) => break,
517 }
518 }
519 });
520 }
521 drop(tx); let retransmit_offsets = [100u64, 500, 1000, 2000];
525 let start = tokio::time::Instant::now();
526 let mut next_retransmit = 0usize;
527
528 loop {
529 let next_retransmit_at = if next_retransmit < retransmit_offsets.len() {
531 start + Duration::from_millis(retransmit_offsets[next_retransmit])
532 } else {
533 deadline
534 };
535 let wake_at = next_retransmit_at.min(deadline);
536
537 tokio::select! {
538 recv = rx.recv() => {
539 let Some((buf, src)) = recv else { break };
540 let mut pkt = PvaPacket::new(&buf);
541 let cmd = pkt
542 .decode_payload()
543 .ok_or(PvGetError::Search("failed to decode search response"))?;
544 if let PvaPacketCommand::SearchResponse(payload) = cmd {
545 if debug_enabled {
546 debug!(
547 "pva search response found={} cids={:?} addr={:?} port={}",
548 payload.found, payload.cids, payload.addr, payload.port
549 );
550 }
551 if payload.seq != seq {
552 continue;
553 }
554 if !payload.protocol.is_empty() && !payload.protocol.eq_ignore_ascii_case("tcp") {
555 continue;
556 }
557 if !payload.found {
558 continue;
559 }
560 if !payload.cids.is_empty() && !payload.cids.contains(&cid) {
561 continue;
562 }
563
564 let addr = decode_search_response_addr(payload.addr, payload.port, src);
565 if debug_enabled {
566 debug!("pva search response from {}", addr);
567 }
568 return Ok(addr);
569 }
570 }
571 _ = tokio::time::sleep_until(wake_at) => {
572 if tokio::time::Instant::now() >= deadline {
573 break;
574 }
575 if next_retransmit < retransmit_offsets.len() {
577 if debug_enabled {
578 debug!("pva search retransmit round {}", next_retransmit + 1);
579 }
580 for (sock, msg, dests) in &socket_info {
581 for dest in dests {
582 let _ = sock.send_to(msg, dest).await;
583 }
584 }
585 next_retransmit += 1;
586 }
587 }
588 }
589 }
590
591 Err(PvGetError::Timeout("search response"))
592}
593
594pub fn default_bind_ip() -> Option<IpAddr> {
595 choose_default_bind_v4()
596 .map(IpAddr::V4)
597 .or_else(|| choose_default_bind_v6().map(IpAddr::V6))
598}
599
/// Parses an `EPICS_PVA_NAME_SERVERS`-style list into TCP socket addresses.
///
/// Entries are separated by commas, spaces or tabs and trimmed; empty entries are skipped.
/// Each entry is tried as follows:
/// 1. as a socket address (`10.0.0.1:5075`, `[::1]:5075`);
/// 2. as a bare IP, given port 5075;
/// 3. through system resolution as given (`host:port`);
/// 4. through system resolution with `:5075` appended.
///
/// The first resolved address is used. Entries that match nothing are dropped.
///
/// NOTE: steps 3 and 4 perform blocking name resolution.
pub fn parse_name_servers(env_val: &str) -> Vec<SocketAddr> {
    use std::net::ToSocketAddrs;

    // First address a resolvable "host:port" string maps to, if any.
    fn first_resolved(spec: &str) -> Option<SocketAddr> {
        spec.to_socket_addrs().ok()?.next()
    }

    env_val
        .split([',', ' ', '\t'])
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
        .filter_map(|entry| {
            if let Ok(sock) = entry.parse::<SocketAddr>() {
                return Some(sock);
            }
            if let Ok(ip) = entry.parse::<IpAddr>() {
                return Some(SocketAddr::new(ip, 5075));
            }
            first_resolved(entry).or_else(|| first_resolved(&format!("{}:5075", entry)))
        })
        .collect()
}
634
635fn encode_search_validation(version: u8, is_be: bool) -> Vec<u8> {
637 let user = default_authnz_user();
638 let host = default_authnz_host();
639 encode_client_connection_validation(87_040, 32_767, 0, "ca", &user, &host, version, is_be)
640}
641
642pub async fn search_pv_tcp(
647 pv_name: &str,
648 name_server: SocketAddr,
649 timeout_dur: Duration,
650 debug_enabled: bool,
651) -> Result<SocketAddr, PvGetError> {
652 let deadline = tokio::time::Instant::now() + timeout_dur;
653
654 let mut stream =
655 tokio::time::timeout(timeout_dur, tokio::net::TcpStream::connect(name_server))
656 .await
657 .map_err(|_| PvGetError::Timeout("name server connect"))??;
658
659 let mut version = 2u8;
660 let mut is_be = false;
661
662 for _ in 0..2 {
664 let now = tokio::time::Instant::now();
665 if now >= deadline {
666 return Err(PvGetError::Timeout("name server handshake"));
667 }
668 let remaining = deadline - now;
669 if let Ok(bytes) = read_packet(&mut stream, remaining).await {
670 let mut pkt = PvaPacket::new(&bytes);
671 if let Some(cmd) = pkt.decode_payload() {
672 match cmd {
673 PvaPacketCommand::Control(payload) => {
674 if payload.command == 2 {
675 is_be = pkt.header.flags.is_msb;
676 }
677 }
678 PvaPacketCommand::ConnectionValidation(_) => {
679 version = pkt.header.version;
680 is_be = pkt.header.flags.is_msb;
681 }
682 _ => {}
683 }
684 }
685 }
686 }
687
688 let validation = encode_search_validation(version, is_be);
689 stream.write_all(&validation).await?;
690
691 loop {
693 let now = tokio::time::Instant::now();
694 if now >= deadline {
695 return Err(PvGetError::Timeout("name server validated"));
696 }
697 let remaining = deadline - now;
698 let bytes = read_packet(&mut stream, remaining).await?;
699 let mut pkt = PvaPacket::new(&bytes);
700 if let Some(cmd) = pkt.decode_payload() {
701 if matches!(cmd, PvaPacketCommand::ConnectionValidated(_)) {
702 break;
703 }
704 }
705 }
706
707 let now_ts = std::time::SystemTime::now()
709 .duration_since(std::time::UNIX_EPOCH)
710 .unwrap_or_default();
711 let seq = (now_ts.as_nanos() as u32).wrapping_add(std::process::id());
712 let cid = seq ^ 0x9E37_79B9;
713 let requests = [(cid, pv_name)];
714 let msg = encode_search_request(seq, 0x80, 0, [0u8; 16], &requests, version, is_be);
715 stream.write_all(&msg).await?;
716
717 if debug_enabled {
718 debug!(
719 "pva tcp search sent to name_server={} pv={}",
720 name_server, pv_name
721 );
722 }
723
724 loop {
726 let now = tokio::time::Instant::now();
727 if now >= deadline {
728 return Err(PvGetError::Timeout("name server search response"));
729 }
730 let remaining = deadline - now;
731 let bytes = read_packet(&mut stream, remaining).await?;
732 let mut pkt = PvaPacket::new(&bytes);
733 if let Some(cmd) = pkt.decode_payload() {
734 if let PvaPacketCommand::SearchResponse(payload) = cmd {
735 if !payload.found {
736 continue;
737 }
738 if !payload.cids.is_empty() && !payload.cids.contains(&cid) {
739 continue;
740 }
741 let addr =
742 decode_search_response_addr(payload.addr, payload.port, name_server);
743 if debug_enabled {
744 debug!(
745 "pva tcp search response from name_server={}: {}",
746 name_server, addr
747 );
748 }
749 return Ok(addr);
750 }
751 }
752 }
753}
754
755pub async fn resolve_pv_server(opts: &PvGetOptions) -> Result<SocketAddr, PvGetError> {
762 if let Some(addr) = opts.server_addr {
763 return Ok(addr);
764 }
765
766 let mut name_servers = opts.name_servers.clone();
767 if let Ok(env) = std::env::var("EPICS_PVA_NAME_SERVERS") {
768 name_servers.extend(parse_name_servers(&env));
769 }
770
771 let no_broadcast = opts.no_broadcast;
772
773 if no_broadcast && name_servers.is_empty() {
775 return Err(PvGetError::Search(
776 "no search strategy: specify --name-server or --server when using --no-broadcast",
777 ));
778 }
779
780 let targets = build_search_targets(opts.search_addr, opts.bind_addr);
783
784 let pv = opts.pv_name.clone();
785 let timeout_dur = opts.timeout;
786 let debug_enabled = opts.debug;
787 let udp_port = opts.udp_port;
788
789 let mut set = tokio::task::JoinSet::new();
790
791 for ns in name_servers {
792 let pv = pv.clone();
793 set.spawn(async move {
794 let addr = search_pv_tcp(&pv, ns, timeout_dur, debug_enabled).await?;
795 Ok::<SocketAddr, PvGetError>(addr)
796 });
797 }
798
799 if !no_broadcast {
800 let pv = pv.clone();
801 let targets = targets.clone();
802 set.spawn(async move {
803 let addr = search_pv(&pv, udp_port, timeout_dur, &targets, debug_enabled).await?;
804 Ok(addr)
805 });
806 }
807
808 let mut last_err = None;
809 while let Some(result) = set.join_next().await {
810 match result {
811 Ok(Ok(addr)) => {
812 set.abort_all();
813 return Ok(addr);
814 }
815 Ok(Err(e)) => {
816 if debug_enabled {
817 debug!("pva search strategy failed: {}", e);
818 }
819 last_err = Some(e);
820 }
821 Err(join_err) => {
822 if debug_enabled {
823 debug!("pva search task panicked: {}", join_err);
824 }
825 }
826 }
827 }
828
829 Err(last_err.unwrap_or(PvGetError::Timeout("search response")))
830}
831
832pub async fn discover_servers(
833 udp_port: u16,
834 timeout_dur: Duration,
835 targets: &[SearchTarget],
836 debug_enabled: bool,
837) -> Result<Vec<DiscoveredServer>, PvGetError> {
838 if targets.is_empty() {
839 return Err(PvGetError::Search("no search targets"));
840 }
841
842 let now = std::time::SystemTime::now()
843 .duration_since(std::time::UNIX_EPOCH)
844 .unwrap_or_default();
845 let seq = (now.as_nanos() as u32).wrapping_add(std::process::id());
846
847 let mut found: Vec<DiscoveredServer> = Vec::new();
848 let mut last_io_error: Option<std::io::Error> = None;
849 let deadline = tokio::time::Instant::now() + timeout_dur;
850
851 let mut bind_groups: Vec<(IpAddr, Vec<IpAddr>)> = Vec::new();
853 for t in targets {
854 if let Some(group) = bind_groups.iter_mut().find(|(b, _)| *b == t.bind) {
855 group.1.push(t.target);
856 } else {
857 bind_groups.push((t.bind, vec![t.target]));
858 }
859 }
860
861 let mut socket_info: Vec<(Arc<UdpSocket>, Vec<u8>, Vec<SocketAddr>)> = Vec::new();
864
865 for (bind_ip, group_targets) in &bind_groups {
866 let bind_addr = SocketAddr::new(*bind_ip, udp_port);
867 let (std_sock, actual_bind_addr) = match bind_udp_reuse(bind_addr) {
868 Ok(sock) => (sock, bind_addr),
869 Err(err) if err.kind() == std::io::ErrorKind::AddrInUse => {
870 let fallback = SocketAddr::new(*bind_ip, 0);
871 match bind_udp_reuse(fallback) {
872 Ok(sock) => {
873 let actual = sock.local_addr().unwrap_or(fallback);
874 if debug_enabled {
875 debug!(
876 "pva discover bind={} failed (in use), fallback bind={}",
877 bind_addr, actual
878 );
879 }
880 (sock, actual)
881 }
882 Err(fallback_err) => {
883 if debug_enabled {
884 debug!(
885 "pva discover skipping bind={} step=bind-fallback kind={:?} err={}",
886 bind_addr,
887 fallback_err.kind(),
888 fallback_err
889 );
890 }
891 last_io_error = Some(fallback_err);
892 continue;
893 }
894 }
895 }
896 Err(err) => {
897 if debug_enabled {
898 debug!(
899 "pva discover skipping bind={} step=bind kind={:?} err={}",
900 bind_addr,
901 err.kind(),
902 err
903 );
904 }
905 last_io_error = Some(err);
906 continue;
907 }
908 };
909 if let Err(err) = std_sock.set_broadcast(true) {
910 if debug_enabled {
911 debug!(
912 "pva discover skipping bind={} step=set_broadcast kind={:?} err={}",
913 bind_addr,
914 err.kind(),
915 err
916 );
917 }
918 last_io_error = Some(err);
919 continue;
920 }
921
922 join_multicast_any(&std_sock, *bind_ip);
923
924 let reply_addr = ip_to_bytes(*bind_ip);
925 let reply_port = match std_sock.local_addr() {
926 Ok(addr) => addr.port(),
927 Err(err) => {
928 if debug_enabled {
929 debug!(
930 "pva discover skipping bind={} step=local_addr kind={:?} err={}",
931 bind_addr,
932 err.kind(),
933 err
934 );
935 }
936 last_io_error = Some(err);
937 continue;
938 }
939 };
940 let msg = encode_search_request(seq, 0x81, reply_port, reply_addr, &[], 2, false);
941
942 let socket = match UdpSocket::from_std(std_sock) {
943 Ok(socket) => socket,
944 Err(err) => {
945 if debug_enabled {
946 debug!(
947 "pva discover skipping bind={} step=from_std kind={:?} err={}",
948 bind_addr,
949 err.kind(),
950 err
951 );
952 }
953 last_io_error = Some(err);
954 continue;
955 }
956 };
957
958 let dests: Vec<SocketAddr> = group_targets
959 .iter()
960 .map(|ip| SocketAddr::new(*ip, udp_port))
961 .collect();
962
963 for dest in &dests {
965 if debug_enabled {
966 debug!(
967 "pva discover bind={} target={} server_port={} reply_port={} seq={}",
968 actual_bind_addr, dest.ip(), udp_port, reply_port, seq
969 );
970 }
971 if let Err(err) = socket.send_to(&msg, dest).await {
972 if debug_enabled {
973 debug!(
974 "pva discover send_to target={} kind={:?} err={}",
975 dest,
976 err.kind(),
977 err
978 );
979 }
980 last_io_error = Some(err);
981 }
982 }
983
984 socket_info.push((Arc::new(socket), msg, dests));
985 }
986
987 if socket_info.is_empty() {
988 if let Some(err) = last_io_error {
989 return Err(PvGetError::Io(err));
990 }
991 return Err(PvGetError::Search("no search targets"));
992 }
993
994 let (tx, mut rx) = tokio::sync::mpsc::channel::<(Vec<u8>, SocketAddr)>(64);
996 for (sock, _, _) in &socket_info {
997 let sock = Arc::clone(sock);
998 let tx = tx.clone();
999 tokio::spawn(async move {
1000 loop {
1001 let mut buf = vec![0u8; 2048];
1002 match sock.recv_from(&mut buf).await {
1003 Ok((len, src)) => {
1004 buf.truncate(len);
1005 if tx.send((buf, src)).await.is_err() {
1006 break;
1007 }
1008 }
1009 Err(_) => break,
1010 }
1011 }
1012 });
1013 }
1014 drop(tx); let retransmit_offsets = [100u64, 500, 1000, 2000];
1018 let start = tokio::time::Instant::now();
1019 let mut next_retransmit = 0usize;
1020
1021 loop {
1022 let next_retransmit_at = if next_retransmit < retransmit_offsets.len() {
1024 start + Duration::from_millis(retransmit_offsets[next_retransmit])
1025 } else {
1026 deadline
1027 };
1028 let wake_at = next_retransmit_at.min(deadline);
1029
1030 tokio::select! {
1031 recv = rx.recv() => {
1032 let Some((buf, src)) = recv else { break };
1033 let mut pkt = PvaPacket::new(&buf);
1034 let Some(cmd) = pkt.decode_payload() else {
1035 continue;
1036 };
1037 if let PvaPacketCommand::SearchResponse(payload) = cmd {
1038 if payload.seq != seq {
1039 continue;
1040 }
1041 if !payload.protocol.is_empty() && !payload.protocol.eq_ignore_ascii_case("tcp") {
1042 continue;
1043 }
1044 let tcp_addr = decode_search_response_addr(payload.addr, payload.port, src);
1045 found.push(DiscoveredServer {
1046 guid: payload.guid,
1047 tcp_addr,
1048 });
1049 }
1050 }
1051 _ = tokio::time::sleep_until(wake_at) => {
1052 if tokio::time::Instant::now() >= deadline {
1053 break;
1054 }
1055 if next_retransmit < retransmit_offsets.len() {
1057 if debug_enabled {
1058 debug!("pva discover retransmit round {}", next_retransmit + 1);
1059 }
1060 for (sock, msg, dests) in &socket_info {
1061 for dest in dests {
1062 let _ = sock.send_to(msg, dest).await;
1063 }
1064 }
1065 next_retransmit += 1;
1066 }
1067 }
1068 }
1069 }
1070
1071 Ok(normalize_discovered_servers(found))
1072}
1073
#[cfg(test)]
mod tests {
    use super::*;
    use spvirit_codec::epics_decode::{PvaPacket, PvaPacketCommand};

    /// Builds an IPv4 `IpAddr` from four octets.
    fn ipv4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    /// Parses a socket-address fixture; a malformed literal is a bug in the test itself.
    fn sock(text: &str) -> SocketAddr {
        text.parse().expect("fixture must be a valid socket address")
    }

    #[test]
    fn encode_decode_search_request_roundtrip() {
        let seq = 1234;
        let cid = 42;
        let port = 5076;
        let pv_name = "TEST:PV";
        let reply_addr = ip_to_bytes(ipv4(192, 168, 1, 20));
        let requests = [(cid, pv_name)];
        let msg = encode_search_request(seq, 0x81, port, reply_addr, &requests, 2, false);

        let mut pkt = PvaPacket::new(&msg);
        match pkt.decode_payload().expect("decoded") {
            PvaPacketCommand::Search(payload) => {
                assert_eq!(payload.seq, seq);
                assert_eq!(payload.mask, 0x81);
                assert_eq!(payload.addr, reply_addr);
                assert_eq!(payload.port, port);
                assert_eq!(payload.protocols, vec!["tcp".to_string()]);
                assert_eq!(payload.pv_requests.len(), 1);
                assert_eq!(payload.pv_requests[0].0, cid);
                assert_eq!(payload.pv_requests[0].1, pv_name.to_string());
            }
            other => panic!("unexpected decode: {:?}", other),
        }
    }

    #[test]
    fn encode_decode_server_discovery_request_roundtrip() {
        let seq = 4321;
        let port = 5076;
        let reply_addr = ip_to_bytes(ipv4(10, 20, 30, 40));
        let msg = encode_search_request(seq, 0x81, port, reply_addr, &[], 2, false);

        let mut pkt = PvaPacket::new(&msg);
        match pkt.decode_payload().expect("decoded") {
            PvaPacketCommand::Search(payload) => {
                assert_eq!(payload.seq, seq);
                assert!(payload.pv_requests.is_empty());
                assert_eq!(payload.protocols, vec!["tcp".to_string()]);
            }
            other => panic!("unexpected decode: {:?}", other),
        }
    }

    #[test]
    fn normalize_discovered_servers_deduplicates_by_guid_and_addr() {
        let addr = sock("127.0.0.1:5075");
        let server = |guid| DiscoveredServer {
            guid,
            tcp_addr: addr,
        };
        let normalized =
            normalize_discovered_servers(vec![server([1u8; 12]), server([1u8; 12]), server([2u8; 12])]);
        assert_eq!(normalized.len(), 2);
    }

    #[test]
    fn parse_addr_list_accepts_ip_and_ip_port() {
        let items = parse_addr_list("192.168.1.10 10.0.0.1:5076");
        for expected in [ipv4(192, 168, 1, 10), ipv4(10, 0, 0, 1)] {
            assert!(items.contains(&expected), "missing {}", expected);
        }
    }

    #[test]
    fn discovery_target_falls_back_to_limited_broadcast_for_invalid_netmask() {
        let ip = Ipv4Addr::new(130, 246, 90, 92);
        let limited = Ipv4Addr::new(255, 255, 255, 255);
        for mask in [Ipv4Addr::new(255, 255, 255, 255), Ipv4Addr::new(0, 0, 0, 0)] {
            assert_eq!(discovery_target_for(ip, mask), limited);
        }
    }

    #[test]
    fn discovery_target_uses_directed_broadcast_for_normal_subnet() {
        let target =
            discovery_target_for(Ipv4Addr::new(192, 168, 56, 1), Ipv4Addr::new(255, 255, 255, 0));
        assert_eq!(target, Ipv4Addr::new(192, 168, 56, 255));
    }

    #[test]
    fn parse_name_servers_ip_with_port() {
        assert_eq!(
            parse_name_servers("192.168.1.10:5075"),
            vec![sock("192.168.1.10:5075")]
        );
    }

    #[test]
    fn parse_name_servers_ip_without_port_defaults_to_5075() {
        assert_eq!(
            parse_name_servers("10.0.0.1"),
            vec![SocketAddr::new(ipv4(10, 0, 0, 1), 5075)]
        );
    }

    #[test]
    fn parse_name_servers_multiple_comma_separated() {
        assert_eq!(
            parse_name_servers("10.0.0.1:5075,10.0.0.2:9876"),
            vec![sock("10.0.0.1:5075"), sock("10.0.0.2:9876")]
        );
    }

    #[test]
    fn parse_name_servers_multiple_space_separated() {
        assert_eq!(
            parse_name_servers("10.0.0.1 10.0.0.2:5075"),
            vec![SocketAddr::new(ipv4(10, 0, 0, 1), 5075), sock("10.0.0.2:5075")]
        );
    }

    #[test]
    fn parse_name_servers_empty_string() {
        assert!(parse_name_servers("").is_empty());
    }

    #[test]
    fn parse_name_servers_whitespace_only() {
        assert!(parse_name_servers(" \t ").is_empty());
    }

    #[test]
    fn parse_name_servers_mixed_separators() {
        assert_eq!(
            parse_name_servers("10.0.0.1:5075, 10.0.0.2 , 10.0.0.3:9999"),
            vec![
                sock("10.0.0.1:5075"),
                SocketAddr::new(ipv4(10, 0, 0, 2), 5075),
                sock("10.0.0.3:9999"),
            ]
        );
    }

    #[test]
    fn parse_name_servers_ipv6_with_port() {
        assert_eq!(
            parse_name_servers("[::1]:5075"),
            vec![SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 5075)]
        );
    }

    #[test]
    fn parse_name_servers_ipv6_without_port() {
        assert_eq!(
            parse_name_servers("::1"),
            vec![SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 5075)]
        );
    }

    #[test]
    fn decode_search_response_addr_falls_back_to_udp_source_when_unspecified() {
        let decoded = decode_search_response_addr([0u8; 16], 5075, sock("192.168.1.20:5076"));
        assert_eq!(decoded, sock("192.168.1.20:5075"));
    }
}