1use std::collections::{HashMap, HashSet};
43use std::future::Future;
44use std::pin::Pin;
45use std::sync::Arc;
46use std::time::Duration;
47
48use gen_fsm::{Action, EventType, FsmDriver, FsmHandler, Transition};
49use parking_lot::Mutex;
50use serde::{Deserialize, Serialize};
51use tokio::sync::mpsc;
52
53use dynvec::SearchResult;
54
55use crate::cluster::apl::{walk_n_successors, ClusterState};
56use crate::embed::events::PeerId;
57
/// Default deadline for a single peer probe, in milliseconds (five seconds).
pub const DEFAULT_PER_PEER_DEADLINE_MS: u64 = 5000;
66
67#[derive(Clone, Debug, PartialEq)]
69pub struct PeerHits {
70 pub peer: String,
72 pub hits: Vec<SearchResult>,
75}
76
77#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
81pub struct SearchRequest {
82 pub table: String,
84 pub vector: Vec<f32>,
86 pub k: usize,
88 pub ef: Option<usize>,
90}
91
92#[derive(Clone, Debug, PartialEq)]
94pub struct SearchResponse {
95 pub hits: Vec<SearchResult>,
97 pub peers_consulted: usize,
99}
100
101pub type PeerProbe =
104 Arc<dyn Fn(&str, SearchRequest) -> Result<Vec<SearchResult>, String> + Send + Sync + 'static>;
105
106#[derive(Debug)]
108pub enum Event {
109 Fanout,
111 Gather,
113 PeerHits(PeerHits),
115 GatherComplete,
117}
118
/// States of the synchronous search [`Coordinator`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum State {
    /// Created, waiting for [`Event::Fanout`].
    Init,
    /// About to probe the peers.
    Fanout,
    /// Collecting per-peer replies.
    Gather,
    /// Merging the collected replies.
    Merge,
}
131
132pub struct Coordinator {
135 request: SearchRequest,
136 peers: Vec<String>,
137 probe: PeerProbe,
138 hits: HashMap<String, Vec<SearchResult>>,
139 response: Arc<Mutex<Option<SearchResponse>>>,
140 deadline: Duration,
144}
145
146impl Coordinator {
147 #[must_use]
156 pub fn new(
157 request: SearchRequest,
158 peers: Vec<String>,
159 probe: PeerProbe,
160 deadline: Duration,
161 ) -> (Self, Arc<Mutex<Option<SearchResponse>>>) {
162 let response = Arc::new(Mutex::new(None));
163 let coord = Self {
164 request,
165 peers,
166 probe,
167 hits: HashMap::new(),
168 response: Arc::clone(&response),
169 deadline,
170 };
171 (coord, response)
172 }
173}
174
175impl FsmHandler for Coordinator {
176 type State = State;
177 type Event = Event;
178 type Reply = ();
179 type Stop = String;
180
181 fn initial(&self) -> Self::State {
182 State::Init
183 }
184
185 fn handle(
186 &mut self,
187 state: Self::State,
188 _event_type: EventType,
189 event: Self::Event,
190 ) -> Transition<Self> {
191 match (state, event) {
192 (State::Init, Event::Fanout) => {
193 Transition::Next(State::Fanout, vec![Action::post_internal(Event::Gather)])
194 }
195 (State::Fanout, Event::Gather) => {
196 let mut completion: Vec<Action<Self>> = Vec::new();
199 for peer in self.peers.clone() {
200 let res = (self.probe)(&peer, self.request.clone());
201 match res {
202 Ok(hits) => {
203 completion.push(Action::post_internal(Event::PeerHits(PeerHits {
204 peer,
205 hits,
206 })));
207 }
208 Err(err) => {
209 tracing::warn!(peer=%peer, error=%err, "peer probe failed");
210 completion.push(Action::post_internal(Event::PeerHits(PeerHits {
213 peer,
214 hits: Vec::new(),
215 })));
216 }
217 }
218 }
219 completion.push(Action::set_state_timeout(self.deadline));
220 if completion.is_empty() {
221 Transition::Next(
222 State::Merge,
223 vec![Action::post_internal(Event::GatherComplete)],
224 )
225 } else {
226 Transition::Next(State::Gather, completion)
227 }
228 }
229 (State::Gather, Event::PeerHits(reply)) => {
230 self.hits.insert(reply.peer, reply.hits);
231 if self.hits.len() >= self.peers.len() {
232 Transition::Next(
233 State::Merge,
234 vec![Action::post_internal(Event::GatherComplete)],
235 )
236 } else {
237 Transition::Keep(vec![])
238 }
239 }
240 (State::Merge, Event::GatherComplete) => {
241 let merged = merge_hits(&self.hits, self.request.k);
242 let response = SearchResponse {
243 hits: merged,
244 peers_consulted: self.hits.values().filter(|h| !h.is_empty()).count(),
245 };
246 *self.response.lock() = Some(response);
247 Transition::Stop("complete".to_string())
248 }
249 (_, _) => Transition::Keep(vec![]),
251 }
252 }
253
254 fn on_timeout(&mut self, state: Self::State, _kind: gen_fsm::TimeoutKind) -> Transition<Self> {
255 match state {
256 State::Gather => Transition::Next(
257 State::Merge,
258 vec![Action::post_internal(Event::GatherComplete)],
259 ),
260 _ => Transition::Keep(vec![]),
261 }
262 }
263}
264
265#[must_use]
270pub fn merge_hits<S: std::hash::BuildHasher>(
271 per_peer: &HashMap<String, Vec<SearchResult>, S>,
272 k: usize,
273) -> Vec<SearchResult> {
274 let mut all: Vec<SearchResult> = per_peer.values().flatten().cloned().collect();
275 all.sort_by(|a, b| {
276 a.score
277 .partial_cmp(&b.score)
278 .unwrap_or(std::cmp::Ordering::Equal)
279 });
280 let mut seen: HashMap<u64, f32> = HashMap::new();
284 let mut deduped: Vec<SearchResult> = Vec::with_capacity(all.len());
285 for r in all {
286 let entry = seen.entry(r.id).or_insert(r.score);
287 if r.score <= *entry {
288 *entry = r.score;
289 deduped.push(r);
290 }
291 }
292 deduped.sort_by(|a, b| {
298 a.score
299 .partial_cmp(&b.score)
300 .unwrap_or(std::cmp::Ordering::Equal)
301 });
302 let mut final_seen: std::collections::HashSet<u64> = std::collections::HashSet::new();
303 let mut out: Vec<SearchResult> = Vec::with_capacity(k);
304 for r in deduped {
305 if final_seen.insert(r.id) {
306 out.push(r);
307 if out.len() >= k {
308 break;
309 }
310 }
311 }
312 out
313}
314
315pub async fn run(
327 request: SearchRequest,
328 peers: Vec<String>,
329 probe: PeerProbe,
330 deadline: Duration,
331) -> Result<SearchResponse, gen_fsm::DriverError> {
332 let (coord, response) = Coordinator::new(request, peers, probe, deadline);
333 let driver = gen_fsm::FsmDriver::start(coord);
334 driver.cast_checked(Event::Fanout).await?;
335 let _stop = driver.join().await?;
336 let final_resp = response.lock().clone().unwrap_or(SearchResponse {
337 hits: Vec::new(),
338 peers_consulted: 0,
339 });
340 Ok(final_resp)
341}
342
343#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
376pub enum SerializedQuery {
377 Knn {
379 vector_field: String,
381 vector_bytes: Vec<u8>,
383 ef: Option<u32>,
385 },
386 Text {
388 field: String,
390 query: Vec<u8>,
392 },
393 Regex {
395 field: String,
397 pattern: String,
399 max_errors: u16,
401 },
402}
403
/// A single document hit returned by a peer.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct HitWithScore {
    /// Document key.
    pub doc_id: Vec<u8>,
    /// Score; [`MergeOrder::ScoreAscending`] ranks smaller values first.
    pub score: f32,
}
418
419#[derive(Clone, Debug, Default, PartialEq)]
427pub struct PeerReply {
428 pub hits: Vec<HitWithScore>,
430 pub timed_out: bool,
432}
433
434#[derive(Clone, Debug, PartialEq, Eq)]
440pub struct BroadcastRequest {
441 pub table: String,
443 pub query: SerializedQuery,
445 pub top_k: u32,
447}
448
449#[derive(Clone, Debug, Default, PartialEq)]
456pub struct BroadcastResponse {
457 pub hits: Vec<HitWithScore>,
459 pub peers_consulted: usize,
462 pub peers_timed_out: usize,
464 pub partial: bool,
467}
468
/// Ordering applied when merging per-peer hits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeOrder {
    /// Ascending score (smallest first), ties broken by `doc_id`.
    ScoreAscending,
    /// Ascending `doc_id`.
    DocIdAscending,
}
482
483pub type AsyncPeerProbe = Arc<
493 dyn Fn(
494 PeerId,
495 BroadcastRequest,
496 )
497 -> Pin<Box<dyn Future<Output = Result<Vec<HitWithScore>, String>> + Send + 'static>>
498 + Send
499 + Sync
500 + 'static,
501>;
502
503#[must_use]
535pub fn select_primary_peers(cluster: &ClusterState) -> Vec<PeerId> {
536 let len = cluster.ring().len();
537 if len == 0 {
538 return Vec::new();
539 }
540 walk_n_successors(cluster, 0, len)
541 .into_iter()
542 .filter(|(_, pid)| cluster.is_alive(*pid))
543 .map(|(_, pid)| pid)
544 .collect()
545}
546
547#[must_use]
549pub const fn default_per_peer_deadline() -> Duration {
550 Duration::from_millis(DEFAULT_PER_PEER_DEADLINE_MS)
551}
552
553#[must_use]
585pub fn merge_hits_ranked(
586 per_peer: &[PeerReply],
587 top_k: u32,
588 order: MergeOrder,
589) -> Vec<HitWithScore> {
590 let cap = usize::try_from(top_k).unwrap_or(usize::MAX);
591 if cap == 0 {
592 return Vec::new();
593 }
594 let mut all: Vec<HitWithScore> = per_peer
595 .iter()
596 .flat_map(|reply| reply.hits.iter().cloned())
597 .collect();
598 sort_hits(&mut all, order);
599 let mut seen: HashSet<Vec<u8>> = HashSet::with_capacity(all.len().min(cap));
600 let mut out: Vec<HitWithScore> = Vec::with_capacity(cap);
601 for hit in all {
602 if seen.insert(hit.doc_id.clone()) {
603 out.push(hit);
604 if out.len() >= cap {
605 break;
606 }
607 }
608 }
609 out
610}
611
612fn sort_hits(hits: &mut [HitWithScore], order: MergeOrder) {
613 match order {
614 MergeOrder::ScoreAscending => {
615 hits.sort_by(|a, b| {
616 a.score
617 .partial_cmp(&b.score)
618 .unwrap_or(std::cmp::Ordering::Equal)
619 .then_with(|| a.doc_id.cmp(&b.doc_id))
620 });
621 }
622 MergeOrder::DocIdAscending => {
623 hits.sort_by(|a, b| a.doc_id.cmp(&b.doc_id));
624 }
625 }
626}
627
628#[derive(Debug)]
632pub enum BroadcastEvent {
633 PeerReplied(PeerReply),
635 AllReceived,
639 MergeDone,
643}
644
/// States of the [`BroadcastCoordinator`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BroadcastState {
    /// No reply received yet; no overall deadline armed.
    Init,
    /// At least one reply received; overall deadline armed.
    Gathering,
    /// All replies recorded (or deadline hit); about to merge and stop.
    Merging,
}
658
659pub struct BroadcastCoordinator {
665 request: BroadcastRequest,
666 expected_peers: usize,
667 replies: Vec<PeerReply>,
668 order: MergeOrder,
669 response: Arc<Mutex<Option<BroadcastResponse>>>,
670 overall_deadline: Duration,
671}
672
673impl BroadcastCoordinator {
674 #[must_use]
676 pub fn new(
677 request: BroadcastRequest,
678 expected_peers: usize,
679 order: MergeOrder,
680 overall_deadline: Duration,
681 ) -> (Self, Arc<Mutex<Option<BroadcastResponse>>>) {
682 let response = Arc::new(Mutex::new(None));
683 let coord = Self {
684 request,
685 expected_peers,
686 replies: Vec::with_capacity(expected_peers),
687 order,
688 response: Arc::clone(&response),
689 overall_deadline,
690 };
691 (coord, response)
692 }
693
694 fn finalise(&self) -> BroadcastResponse {
695 let timed_out = self.replies.iter().filter(|r| r.timed_out).count();
696 let consulted = self.replies.len();
697 let merged = merge_hits_ranked(&self.replies, self.request.top_k, self.order);
698 BroadcastResponse {
699 hits: merged,
700 peers_consulted: consulted,
701 peers_timed_out: timed_out,
702 partial: timed_out > 0 || consulted < self.expected_peers,
703 }
704 }
705}
706
707impl FsmHandler for BroadcastCoordinator {
708 type State = BroadcastState;
709 type Event = BroadcastEvent;
710 type Reply = ();
711 type Stop = String;
712
713 fn initial(&self) -> Self::State {
714 BroadcastState::Init
715 }
716
717 fn handle(
718 &mut self,
719 state: Self::State,
720 _event_type: EventType,
721 event: Self::Event,
722 ) -> Transition<Self> {
723 match (state, event) {
724 (BroadcastState::Init | BroadcastState::Gathering, BroadcastEvent::PeerReplied(r)) => {
725 self.replies.push(r);
726 if self.replies.len() >= self.expected_peers {
727 Transition::Next(
728 BroadcastState::Merging,
729 vec![
730 Action::cancel_state_timeout(),
731 Action::post_internal(BroadcastEvent::AllReceived),
732 ],
733 )
734 } else if state == BroadcastState::Init {
735 Transition::Next(
736 BroadcastState::Gathering,
737 vec![Action::set_state_timeout(self.overall_deadline)],
738 )
739 } else {
740 Transition::Keep(vec![])
741 }
742 }
743 (BroadcastState::Merging, BroadcastEvent::AllReceived | BroadcastEvent::MergeDone) => {
744 let resp = self.finalise();
745 *self.response.lock() = Some(resp);
746 Transition::Stop("broadcast complete".to_string())
747 }
748 _ => Transition::Keep(vec![]),
754 }
755 }
756
757 fn on_timeout(&mut self, state: Self::State, _kind: gen_fsm::TimeoutKind) -> Transition<Self> {
758 if matches!(state, BroadcastState::Gathering | BroadcastState::Init) {
759 while self.replies.len() < self.expected_peers {
762 self.replies.push(PeerReply {
763 hits: Vec::new(),
764 timed_out: true,
765 });
766 }
767 Transition::Next(
768 BroadcastState::Merging,
769 vec![Action::post_internal(BroadcastEvent::AllReceived)],
770 )
771 } else {
772 Transition::Keep(vec![])
773 }
774 }
775}
776
777pub async fn broadcast(
806 request: BroadcastRequest,
807 peers: Vec<PeerId>,
808 probe: AsyncPeerProbe,
809 per_peer_deadline: Duration,
810 order: MergeOrder,
811) -> Result<BroadcastResponse, gen_fsm::DriverError> {
812 if peers.is_empty() {
813 return Ok(BroadcastResponse {
814 hits: Vec::new(),
815 peers_consulted: 0,
816 peers_timed_out: 0,
817 partial: true,
818 });
819 }
820 let overall = per_peer_deadline
826 .saturating_mul(2)
827 .saturating_add(Duration::from_secs(1));
828 let n = peers.len();
829 let (handler, response) = BroadcastCoordinator::new(request.clone(), n, order, overall);
830 let driver: FsmDriver<BroadcastCoordinator> = FsmDriver::start(handler);
831 let (reply_tx, mut reply_rx) = mpsc::channel::<PeerReply>(n);
832 for peer in peers {
833 let probe = Arc::clone(&probe);
834 let req = request.clone();
835 let tx = reply_tx.clone();
836 tokio::spawn(async move {
837 let fut = probe(peer, req);
838 let reply = match tokio::time::timeout(per_peer_deadline, fut).await {
839 Ok(Ok(hits)) => PeerReply {
840 hits,
841 timed_out: false,
842 },
843 Ok(Err(err)) => {
844 tracing::warn!(peer=peer, error=%err, "FT.SEARCH peer probe failed");
845 PeerReply {
846 hits: Vec::new(),
847 timed_out: false,
848 }
849 }
850 Err(_) => {
851 tracing::warn!(
852 peer = peer,
853 "FT.SEARCH peer probe timed out (per-peer deadline elapsed)"
854 );
855 PeerReply {
856 hits: Vec::new(),
857 timed_out: true,
858 }
859 }
860 };
861 let _ = tx.send(reply).await;
862 });
863 }
864 drop(reply_tx);
865 let driver_for_pump = driver.clone();
866 let pump = tokio::spawn(async move {
867 while let Some(reply) = reply_rx.recv().await {
868 if driver_for_pump
869 .cast_checked(BroadcastEvent::PeerReplied(reply))
870 .await
871 .is_err()
872 {
873 break;
874 }
875 }
876 });
877 let _ = driver.join().await?;
878 let _ = pump.await;
879 let final_resp = response
880 .lock()
881 .clone()
882 .unwrap_or_else(|| BroadcastResponse {
883 hits: Vec::new(),
884 peers_consulted: 0,
885 peers_timed_out: n,
886 partial: true,
887 });
888 Ok(final_resp)
889}
890
#[cfg(test)]
mod tests {
    use super::*;
    use dynvec::SearchResult;

    /// Baseline request: table "t", 4-dimensional zero vector, k = 3.
    fn req() -> SearchRequest {
        SearchRequest {
            table: "t".to_string(),
            vector: vec![0.0; 4],
            k: 3,
            ef: None,
        }
    }

    /// Hits from two peers are interleaved by ascending score and capped at k.
    #[tokio::test]
    async fn merges_hits_from_multiple_peers() {
        let hits_p1 = vec![
            SearchResult { id: 1, score: 0.1 },
            SearchResult { id: 2, score: 0.5 },
        ];
        let hits_p2 = vec![
            SearchResult { id: 3, score: 0.2 },
            SearchResult { id: 4, score: 0.6 },
        ];
        let probe: PeerProbe = Arc::new(move |peer, _r| match peer {
            "p1" => Ok(hits_p1.clone()),
            "p2" => Ok(hits_p2.clone()),
            _ => Err("unknown peer".to_string()),
        });
        let resp = run(
            req(),
            vec!["p1".to_string(), "p2".to_string()],
            probe,
            Duration::from_secs(1),
        )
        .await
        .unwrap();
        assert_eq!(resp.peers_consulted, 2);
        assert_eq!(resp.hits.len(), 3);
        assert_eq!(resp.hits[0].id, 1);
        assert_eq!(resp.hits[1].id, 3);
        assert_eq!(resp.hits[2].id, 2);
    }

    /// A failing peer contributes nothing and is not counted as consulted.
    #[tokio::test]
    async fn missing_peers_are_tolerated() {
        let probe: PeerProbe = Arc::new(|peer, _r| match peer {
            "good" => Ok(vec![SearchResult { id: 1, score: 0.1 }]),
            _ => Err("dead".to_string()),
        });
        let resp = run(
            req(),
            vec!["good".to_string(), "bad".to_string()],
            probe,
            Duration::from_secs(1),
        )
        .await
        .unwrap();
        assert_eq!(resp.peers_consulted, 1);
        assert_eq!(resp.hits.len(), 1);
        assert_eq!(resp.hits[0].id, 1);
    }

    /// The same id from two peers collapses to its smallest score.
    #[tokio::test]
    async fn duplicate_ids_collapsed() {
        let probe: PeerProbe = Arc::new(|peer, _r| match peer {
            "p1" => Ok(vec![SearchResult { id: 1, score: 0.10 }]),
            "p2" => Ok(vec![SearchResult { id: 1, score: 0.05 }]),
            _ => Err("unknown".to_string()),
        });
        let resp = run(
            SearchRequest {
                table: "t".to_string(),
                vector: vec![],
                k: 2,
                ef: None,
            },
            vec!["p1".to_string(), "p2".to_string()],
            probe,
            Duration::from_secs(1),
        )
        .await
        .unwrap();
        assert_eq!(resp.hits.len(), 1);
        assert!((resp.hits[0].score - 0.05).abs() < 1e-6);
    }

    use std::collections::HashSet;

    use crate::cluster::apl::{ClusterState, RingPoint};

    /// k-NN broadcast request over index "idx" returning at most `top_k` hits.
    fn knn_request(top_k: u32) -> BroadcastRequest {
        BroadcastRequest {
            table: "idx".into(),
            query: SerializedQuery::Knn {
                vector_field: "v".into(),
                vector_bytes: vec![0u8; 16],
                ef: None,
            },
            top_k,
        }
    }

    /// Probe that immediately answers with the canned hits for each peer
    /// (no hits for unknown peers).
    fn fixed_probe(per_peer: HashMap<PeerId, Vec<HitWithScore>>) -> AsyncPeerProbe {
        Arc::new(move |peer, _req| {
            let hits = per_peer.get(&peer).cloned().unwrap_or_default();
            Box::pin(async move { Ok(hits) })
        })
    }

    // The merge_* tests exercise pure functions and need no async runtime.

    #[test]
    fn merge_score_ascending_picks_smallest_scores() {
        let p0 = PeerReply {
            hits: vec![
                HitWithScore {
                    doc_id: b"a".to_vec(),
                    score: 0.1,
                },
                HitWithScore {
                    doc_id: b"b".to_vec(),
                    score: 0.5,
                },
            ],
            timed_out: false,
        };
        let p1 = PeerReply {
            hits: vec![
                HitWithScore {
                    doc_id: b"c".to_vec(),
                    score: 0.05,
                },
                HitWithScore {
                    doc_id: b"d".to_vec(),
                    score: 0.6,
                },
            ],
            timed_out: false,
        };
        let merged = merge_hits_ranked(&[p0, p1], 3, MergeOrder::ScoreAscending);
        assert_eq!(merged.len(), 3);
        assert_eq!(merged[0].doc_id, b"c");
        assert_eq!(merged[1].doc_id, b"a");
        assert_eq!(merged[2].doc_id, b"b");
    }

    #[test]
    fn merge_doc_id_ascending_orders_by_key() {
        let p0 = PeerReply {
            hits: vec![
                HitWithScore {
                    doc_id: b"key:9".to_vec(),
                    score: 0.0,
                },
                HitWithScore {
                    doc_id: b"key:1".to_vec(),
                    score: 0.0,
                },
            ],
            timed_out: false,
        };
        let p1 = PeerReply {
            hits: vec![HitWithScore {
                doc_id: b"key:5".to_vec(),
                score: 0.0,
            }],
            timed_out: false,
        };
        let merged = merge_hits_ranked(&[p0, p1], 5, MergeOrder::DocIdAscending);
        assert_eq!(
            merged.iter().map(|h| h.doc_id.clone()).collect::<Vec<_>>(),
            vec![b"key:1".to_vec(), b"key:5".to_vec(), b"key:9".to_vec()],
        );
    }

    #[test]
    fn merge_dedups_doc_ids_in_score_order() {
        let p0 = PeerReply {
            hits: vec![HitWithScore {
                doc_id: b"a".to_vec(),
                score: 0.10,
            }],
            timed_out: false,
        };
        let p1 = PeerReply {
            hits: vec![HitWithScore {
                doc_id: b"a".to_vec(),
                score: 0.05,
            }],
            timed_out: false,
        };
        let merged = merge_hits_ranked(&[p0, p1], 5, MergeOrder::ScoreAscending);
        assert_eq!(merged.len(), 1);
        assert!((merged[0].score - 0.05).abs() < 1e-6);
    }

    #[test]
    fn merge_top_k_zero_returns_empty() {
        let p = PeerReply {
            hits: vec![HitWithScore {
                doc_id: b"a".to_vec(),
                score: 0.1,
            }],
            timed_out: false,
        };
        assert!(merge_hits_ranked(&[p], 0, MergeOrder::ScoreAscending).is_empty());
    }

    /// No peers: short-circuits to an empty, partial response.
    #[tokio::test]
    async fn broadcast_with_no_peers_returns_partial_empty() {
        let probe: AsyncPeerProbe = Arc::new(|_peer, _req| Box::pin(async { Ok(Vec::new()) }));
        let resp = broadcast(
            knn_request(5),
            Vec::new(),
            probe,
            Duration::from_millis(50),
            MergeOrder::ScoreAscending,
        )
        .await
        .unwrap();
        assert!(resp.hits.is_empty());
        assert_eq!(resp.peers_consulted, 0);
        assert!(resp.partial);
    }

    #[tokio::test]
    async fn broadcast_one_peer_returns_local_top_k() {
        let mut per_peer: HashMap<PeerId, Vec<HitWithScore>> = HashMap::new();
        per_peer.insert(
            7,
            vec![
                HitWithScore {
                    doc_id: b"a".to_vec(),
                    score: 0.10,
                },
                HitWithScore {
                    doc_id: b"b".to_vec(),
                    score: 0.30,
                },
            ],
        );
        let resp = broadcast(
            knn_request(2),
            vec![7],
            fixed_probe(per_peer),
            Duration::from_millis(200),
            MergeOrder::ScoreAscending,
        )
        .await
        .unwrap();
        assert_eq!(resp.peers_consulted, 1);
        assert_eq!(resp.peers_timed_out, 0);
        assert!(!resp.partial);
        assert_eq!(resp.hits.len(), 2);
        assert_eq!(resp.hits[0].doc_id, b"a");
    }

    #[tokio::test]
    async fn broadcast_two_peers_merges() {
        let mut per_peer: HashMap<PeerId, Vec<HitWithScore>> = HashMap::new();
        per_peer.insert(
            1,
            vec![
                HitWithScore {
                    doc_id: b"a".to_vec(),
                    score: 0.10,
                },
                HitWithScore {
                    doc_id: b"b".to_vec(),
                    score: 0.40,
                },
            ],
        );
        per_peer.insert(
            2,
            vec![
                HitWithScore {
                    doc_id: b"c".to_vec(),
                    score: 0.05,
                },
                HitWithScore {
                    doc_id: b"d".to_vec(),
                    score: 0.50,
                },
            ],
        );
        let resp = broadcast(
            knn_request(3),
            vec![1, 2],
            fixed_probe(per_peer),
            Duration::from_millis(200),
            MergeOrder::ScoreAscending,
        )
        .await
        .unwrap();
        assert_eq!(resp.peers_consulted, 2);
        assert_eq!(resp.hits.len(), 3);
        assert_eq!(resp.hits[0].doc_id, b"c");
        assert_eq!(resp.hits[1].doc_id, b"a");
        assert_eq!(resp.hits[2].doc_id, b"b");
    }

    /// Peer 9 sleeps past the 50 ms per-peer deadline and is reported as timed out.
    #[tokio::test]
    async fn broadcast_one_peer_timeout_marks_partial() {
        let probe: AsyncPeerProbe = Arc::new(move |peer, _req| {
            Box::pin(async move {
                if peer == 9 {
                    tokio::time::sleep(Duration::from_millis(500)).await;
                    Ok(Vec::new())
                } else {
                    Ok(vec![HitWithScore {
                        doc_id: b"x".to_vec(),
                        score: 0.10,
                    }])
                }
            })
        });
        let resp = broadcast(
            knn_request(3),
            vec![1, 9],
            probe,
            Duration::from_millis(50),
            MergeOrder::ScoreAscending,
        )
        .await
        .unwrap();
        assert_eq!(resp.peers_consulted, 2);
        assert_eq!(resp.peers_timed_out, 1);
        assert!(resp.partial);
        assert_eq!(resp.hits.len(), 1);
        assert_eq!(resp.hits[0].doc_id, b"x");
    }

    #[tokio::test]
    async fn broadcast_all_peers_timeout_returns_empty_partial() {
        let probe: AsyncPeerProbe = Arc::new(|_peer, _req| {
            Box::pin(async move {
                tokio::time::sleep(Duration::from_millis(500)).await;
                Ok(Vec::new())
            })
        });
        let resp = broadcast(
            knn_request(3),
            vec![1, 2, 3],
            probe,
            Duration::from_millis(40),
            MergeOrder::ScoreAscending,
        )
        .await
        .unwrap();
        assert_eq!(resp.peers_consulted, 3);
        assert_eq!(resp.peers_timed_out, 3);
        assert!(resp.partial);
        assert!(resp.hits.is_empty());
    }

    // The select_primary_peers tests stay on the tokio harness because
    // ClusterState's runtime requirements are not visible from this file.
    // NOTE(review): the second ClusterState::new argument appears to be the
    // alive-peer set; confirm against cluster::apl.

    #[tokio::test]
    async fn select_primary_peers_returns_one_per_distinct_alive_peer() {
        let cs = ClusterState::new(
            vec![
                RingPoint::new(100, 0),
                RingPoint::new(200, 1),
                RingPoint::new(300, 2),
            ],
            [0u32, 1, 2].into_iter().collect::<HashSet<_>>(),
        );
        let mut peers = select_primary_peers(&cs);
        peers.sort_unstable();
        assert_eq!(peers, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn select_primary_peers_filters_dead_peers() {
        let cs = ClusterState::new(
            vec![
                RingPoint::new(100, 0),
                RingPoint::new(200, 1),
                RingPoint::new(300, 2),
            ],
            [0u32, 2].into_iter().collect::<HashSet<_>>(),
        );
        let mut peers = select_primary_peers(&cs);
        peers.sort_unstable();
        assert_eq!(peers, vec![0, 2]);
    }

    #[tokio::test]
    async fn select_primary_peers_dedups_multi_vnode_peers() {
        let cs = ClusterState::new(
            vec![
                RingPoint::new(100, 0),
                RingPoint::new(200, 0),
                RingPoint::new(300, 1),
            ],
            [0u32, 1].into_iter().collect::<HashSet<_>>(),
        );
        let mut peers = select_primary_peers(&cs);
        peers.sort_unstable();
        assert_eq!(peers, vec![0, 1]);
    }
}