Skip to main content

dynomite/vector/
query_fsm.rs

1//! Distributed k-NN coordinator.
2//!
3//! When an FT.SEARCH command lands on any node, the query is
4//! broadcast to every primary peer covering the index's key
5//! range, each peer runs the search against its local HNSW
6//! index, and the coordinator merges the per-peer top-K
7//! results.
8//!
9//! The coordinator is shaped as a [`gen_fsm::FsmHandler`] state
10//! machine with four states:
11//!
12//! ```text
13//!     Init  ->  Fanout  ->  Gather  ->  Merge  ->  (stopped)
14//! ```
15//!
16//! State responsibilities:
17//!
//! * [`State::Init`]: waits for the [`Event::Fanout`] kick-off
//!   (cast by [`run`]), then moves to [`State::Fanout`] and posts
//!   an internal [`Event::Gather`].
//! * [`State::Fanout`]: on [`Event::Gather`], calls the supplied
//!   [`PeerProbe`] once per peer, queues one [`Event::PeerHits`]
//!   per peer, and moves to [`State::Gather`].
23//! * [`State::Gather`]: receives [`Event::PeerHits`] events. Once
24//!   either every peer has replied or the deadline elapses, it
25//!   moves to [`State::Merge`].
26//! * [`State::Merge`]: collapses the per-peer hits down to a
27//!   global top-K and stashes the result on the response cell
28//!   the caller holds.
29//!
30//! The coordinator does not perform any I/O on its own; the
31//! [`PeerProbe`] callback is supplied by the caller and is
32//! responsible for actually contacting peers. This keeps the
33//! FSM testable in-process without standing up a real cluster.
34//!
35//! Phase B (this commit) places the FSM under
36//! `dynomite::vector` so the future Phase C wiring can connect
37//! it to the existing [`crate::cluster::apl`] preference-list
38//! walker and [`crate::cluster::vnode::dispatch`] without a
39//! cross-crate dependency. The [`PeerProbe`] callback remains
40//! the integration seam.
41
42use std::collections::{HashMap, HashSet};
43use std::future::Future;
44use std::pin::Pin;
45use std::sync::Arc;
46use std::time::Duration;
47
48use gen_fsm::{Action, EventType, FsmDriver, FsmHandler, Transition};
49use parking_lot::Mutex;
50use serde::{Deserialize, Serialize};
51use tokio::sync::mpsc;
52
53use dynvec::SearchResult;
54
55use crate::cluster::apl::{walk_n_successors, ClusterState};
56use crate::embed::events::PeerId;
57
/// Default per-peer deadline, in milliseconds, for callers of
/// [`broadcast`].
///
/// 5 seconds matches the operational target captured in the
/// PLAN.md FT.SEARCH wire ticket. [`broadcast`] does not apply
/// this value on its own: it always takes an explicit
/// [`Duration`], so operators that prefer a shorter or longer
/// ceiling simply pass a different one; tests typically use a
/// much smaller value to avoid slowing the suite.
/// [`default_per_peer_deadline`] exposes the same value as a
/// [`Duration`].
pub const DEFAULT_PER_PEER_DEADLINE_MS: u64 = 5_000;
66
/// One per-peer reply, carried by [`Event::PeerHits`].
///
/// The [`Coordinator`] stores replies keyed by `peer`, so a
/// second reply from the same peer replaces the first.
#[derive(Clone, Debug, PartialEq)]
pub struct PeerHits {
    /// Identifier of the peer that produced the hits.
    pub peer: String,
    /// Hits returned by that peer's local search, already sorted
    /// closest-first. Empty when the probe failed.
    pub hits: Vec<SearchResult>,
}
76
/// k-NN query request. The coordinator does not interpret
/// `vector` directly; that is the caller's job (the
/// [`PeerProbe`] receives the entire request unchanged).
///
/// Of these fields the coordinator itself only reads `k`, which
/// bounds the merged result.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SearchRequest {
    /// Index name (the FT.CREATE first argument).
    pub table: String,
    /// Query vector in `f32`.
    pub vector: Vec<f32>,
    /// Number of results to return; passed to [`merge_hits`] as
    /// the global top-K bound.
    pub k: usize,
    /// Optional override of the index's default `ef_search`.
    pub ef: Option<usize>,
}
91
/// Final response sent back to the client.
#[derive(Clone, Debug, PartialEq)]
pub struct SearchResponse {
    /// Top-K hits across the whole cluster.
    pub hits: Vec<SearchResult>,
    /// Number of peers that contributed at least one hit. Peers
    /// whose probe failed, or that returned an empty list, are
    /// not counted.
    pub peers_consulted: usize,
}
100
/// Type-erased peer probe. Returns the per-peer hit list for
/// `request`, or an error message if the peer is unreachable.
///
/// The first argument is the peer identifier; the second is a
/// clone of the request. The [`Coordinator`] calls the probe
/// synchronously, once per peer, from inside its event handler;
/// an `Err` is logged and treated as an empty hit list.
pub type PeerProbe =
    Arc<dyn Fn(&str, SearchRequest) -> Result<Vec<SearchResult>, String> + Send + Sync + 'static>;
105
/// FSM event types.
#[derive(Debug)]
pub enum Event {
    /// Kick-off: move from Init -> Fanout. Cast into the FSM by
    /// [`run`] right after the driver starts.
    Fanout,
    /// Internal: triggers the per-peer probes and moves
    /// Fanout -> Gather.
    Gather,
    /// A peer's search completed. The coordinator posts one of
    /// these to itself per probed peer (with empty hits when the
    /// probe failed).
    PeerHits(PeerHits),
    /// Internal: gathering is over, either because every peer
    /// has replied or because the state timeout fired; produce
    /// the merged response and stop.
    GatherComplete,
}
118
/// FSM states.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum State {
    /// Initial state: waiting for the [`Event::Fanout`] kick-off.
    Init,
    /// Issuing requests to peers; the probes run when the queued
    /// [`Event::Gather`] is handled in this state.
    Fanout,
    /// Waiting for replies: collects [`Event::PeerHits`] until
    /// every peer is recorded or the state timeout fires.
    Gather,
    /// Producing the merged result on [`Event::GatherComplete`],
    /// after which the FSM stops.
    Merge,
}
131
/// Coordinator handler. One instance is bound to one in-flight
/// query; finalising the FSM produces a [`SearchResponse`].
pub struct Coordinator {
    // Request handed (cloned) to every probe; `request.k` bounds the merge.
    request: SearchRequest,
    // Peer identifiers to probe, in probe order.
    peers: Vec<String>,
    // Caller-supplied callback that performs the per-peer search.
    probe: PeerProbe,
    // Replies received so far, keyed by peer identifier.
    hits: HashMap<String, Vec<SearchResult>>,
    // Cell shared with the caller; filled in when the merge completes.
    response: Arc<Mutex<Option<SearchResponse>>>,
    /// Gather deadline: armed as the state timeout when the FSM
    /// enters [`State::Gather`]. If it fires before every peer
    /// has replied, the coordinator merges what it has.
    deadline: Duration,
}
145
146impl Coordinator {
147    /// Build a new coordinator. `peers` is the list of peer
148    /// identifiers the request will fan out to; `probe` is
149    /// invoked synchronously per peer to fetch hits.
150    ///
151    /// The coordinator's `peers_consulted` field on the eventual
152    /// response counts the number of peers that returned hits
153    /// (errors are logged through `tracing::warn!` and
154    /// otherwise dropped).
155    #[must_use]
156    pub fn new(
157        request: SearchRequest,
158        peers: Vec<String>,
159        probe: PeerProbe,
160        deadline: Duration,
161    ) -> (Self, Arc<Mutex<Option<SearchResponse>>>) {
162        let response = Arc::new(Mutex::new(None));
163        let coord = Self {
164            request,
165            peers,
166            probe,
167            hits: HashMap::new(),
168            response: Arc::clone(&response),
169            deadline,
170        };
171        (coord, response)
172    }
173}
174
175impl FsmHandler for Coordinator {
176    type State = State;
177    type Event = Event;
178    type Reply = ();
179    type Stop = String;
180
181    fn initial(&self) -> Self::State {
182        State::Init
183    }
184
185    fn handle(
186        &mut self,
187        state: Self::State,
188        _event_type: EventType,
189        event: Self::Event,
190    ) -> Transition<Self> {
191        match (state, event) {
192            (State::Init, Event::Fanout) => {
193                Transition::Next(State::Fanout, vec![Action::post_internal(Event::Gather)])
194            }
195            (State::Fanout, Event::Gather) => {
196                // Issue probes synchronously, post per-peer hits
197                // back on the FSM mailbox.
198                let mut completion: Vec<Action<Self>> = Vec::new();
199                for peer in self.peers.clone() {
200                    let res = (self.probe)(&peer, self.request.clone());
201                    match res {
202                        Ok(hits) => {
203                            completion.push(Action::post_internal(Event::PeerHits(PeerHits {
204                                peer,
205                                hits,
206                            })));
207                        }
208                        Err(err) => {
209                            tracing::warn!(peer=%peer, error=%err, "peer probe failed");
210                            // Record an empty reply so the
211                            // gather predicate still terminates.
212                            completion.push(Action::post_internal(Event::PeerHits(PeerHits {
213                                peer,
214                                hits: Vec::new(),
215                            })));
216                        }
217                    }
218                }
219                completion.push(Action::set_state_timeout(self.deadline));
220                if completion.is_empty() {
221                    Transition::Next(
222                        State::Merge,
223                        vec![Action::post_internal(Event::GatherComplete)],
224                    )
225                } else {
226                    Transition::Next(State::Gather, completion)
227                }
228            }
229            (State::Gather, Event::PeerHits(reply)) => {
230                self.hits.insert(reply.peer, reply.hits);
231                if self.hits.len() >= self.peers.len() {
232                    Transition::Next(
233                        State::Merge,
234                        vec![Action::post_internal(Event::GatherComplete)],
235                    )
236                } else {
237                    Transition::Keep(vec![])
238                }
239            }
240            (State::Merge, Event::GatherComplete) => {
241                let merged = merge_hits(&self.hits, self.request.k);
242                let response = SearchResponse {
243                    hits: merged,
244                    peers_consulted: self.hits.values().filter(|h| !h.is_empty()).count(),
245                };
246                *self.response.lock() = Some(response);
247                Transition::Stop("complete".to_string())
248            }
249            // Defensive: ignore stray events rather than panicking.
250            (_, _) => Transition::Keep(vec![]),
251        }
252    }
253
254    fn on_timeout(&mut self, state: Self::State, _kind: gen_fsm::TimeoutKind) -> Transition<Self> {
255        match state {
256            State::Gather => Transition::Next(
257                State::Merge,
258                vec![Action::post_internal(Event::GatherComplete)],
259            ),
260            _ => Transition::Keep(vec![]),
261        }
262    }
263}
264
265/// Merge per-peer hit lists into a global top-K.
266///
267/// Each per-peer list is assumed to be sorted closest-first.
268/// The merge is a heap-of-iterators: O((P*K) log P).
269#[must_use]
270pub fn merge_hits<S: std::hash::BuildHasher>(
271    per_peer: &HashMap<String, Vec<SearchResult>, S>,
272    k: usize,
273) -> Vec<SearchResult> {
274    let mut all: Vec<SearchResult> = per_peer.values().flatten().cloned().collect();
275    all.sort_by(|a, b| {
276        a.score
277            .partial_cmp(&b.score)
278            .unwrap_or(std::cmp::Ordering::Equal)
279    });
280    // Deduplicate by id; if the same id appears in multiple
281    // peers' replies (duplicate replication), keep the smallest
282    // score.
283    let mut seen: HashMap<u64, f32> = HashMap::new();
284    let mut deduped: Vec<SearchResult> = Vec::with_capacity(all.len());
285    for r in all {
286        let entry = seen.entry(r.id).or_insert(r.score);
287        if r.score <= *entry {
288            *entry = r.score;
289            deduped.push(r);
290        }
291    }
292    // After dedup, re-sort and take top-k. The deduped vec may
293    // have multiple entries for the same id (one per peer that
294    // returned it); the dedup step below keeps only the
295    // first occurrence of each id, which is the lowest-scored
296    // because the input was sorted.
297    deduped.sort_by(|a, b| {
298        a.score
299            .partial_cmp(&b.score)
300            .unwrap_or(std::cmp::Ordering::Equal)
301    });
302    let mut final_seen: std::collections::HashSet<u64> = std::collections::HashSet::new();
303    let mut out: Vec<SearchResult> = Vec::with_capacity(k);
304    for r in deduped {
305        if final_seen.insert(r.id) {
306            out.push(r);
307            if out.len() >= k {
308                break;
309            }
310        }
311    }
312    out
313}
314
315/// Drive the coordinator to completion.
316///
317/// This is the public entry point: a caller with a
318/// [`SearchRequest`], a peer list, and a [`PeerProbe`] can build
319/// the FSM, post the initial event, wait for completion, and
320/// extract the [`SearchResponse`].
321///
322/// # Errors
323///
324/// Surfaces any [`gen_fsm::DriverError`] from the underlying
325/// FSM driver.
326pub async fn run(
327    request: SearchRequest,
328    peers: Vec<String>,
329    probe: PeerProbe,
330    deadline: Duration,
331) -> Result<SearchResponse, gen_fsm::DriverError> {
332    let (coord, response) = Coordinator::new(request, peers, probe, deadline);
333    let driver = gen_fsm::FsmDriver::start(coord);
334    driver.cast_checked(Event::Fanout).await?;
335    let _stop = driver.join().await?;
336    let final_resp = response.lock().clone().unwrap_or(SearchResponse {
337        hits: Vec::new(),
338        peers_consulted: 0,
339    });
340    Ok(final_resp)
341}
342
343// =====================================================================
344// Cluster-coordinated FT.SEARCH coordinator.
345// =====================================================================
346//
347// The block below extends the original local-only [`Coordinator`]
348// with a properly distributed broadcast path that:
349//
350//   * fans out the request to every primary peer covering the
351//     index's key range;
352//   * applies a per-peer deadline (each peer is timed out
353//     independently of the others);
354//   * merges per-peer top-K lists with explicit ranking
355//     (score-ascending for k-NN, doc-id-ascending for the
356//     trigram and regex text paths);
357//   * surfaces partial results when one or more peers time out,
358//     rather than failing the whole query.
359//
360// The state machine still uses [`gen_fsm`]: the orchestrator
361// spawns one task per peer (each task wraps the probe in a
362// [`tokio::time::timeout`]), and each task posts a
363// [`BroadcastEvent::PeerReplied`] event back to the FSM. The FSM
364// transitions Init -> Gathering(N/n) -> Merging -> Done as the
365// per-peer replies come in; an overall safety-net deadline
366// transitions Gathering -> Merging-with-partial.
367
/// Wire-shape-friendly representation of the FT.SEARCH query
/// the coordinator broadcasts to peers.
///
/// The coordinator does not interpret the contents itself; per-peer
/// query execution decodes the variant and runs the matching local
/// path (k-NN against the HNSW engine, trigram substring match
/// against the inverted index, or TRE-backed regex match).
///
/// The enum travels inside [`BroadcastRequest::query`] and is
/// cloned once per peer by [`broadcast`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum SerializedQuery {
    /// `FT.SEARCH idx "*=>[KNN k @field $param]"` form.
    Knn {
        /// Schema vector field name (the `@field` token).
        vector_field: String,
        /// Raw little-endian f32 query bytes.
        vector_bytes: Vec<u8>,
        /// Optional override of the index's default `ef_search`.
        ef: Option<u32>,
    },
    /// `FT.SEARCH idx "@field:substring"` form.
    Text {
        /// Schema TEXT field name.
        field: String,
        /// Raw substring bytes.
        query: Vec<u8>,
    },
    /// `FT.REGEX idx field pattern [K=n]` (Dynomite extension).
    Regex {
        /// Schema TEXT field name.
        field: String,
        /// POSIX-extended regex pattern.
        pattern: String,
        /// `K=` parameter; zero selects the exact-regex path.
        max_errors: u16,
    },
}
403
/// One cluster-wide FT.SEARCH hit.
///
/// Distinct from [`SearchResult`] (which uses an HNSW-internal
/// `u64` id): the cluster coordinator works in user-visible
/// document keys because the same logical document may sit on
/// different peers under different internal ids.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct HitWithScore {
    /// User-visible document key (the HSET `key` argument).
    /// Merging deduplicates on this key and orders it bytewise.
    pub doc_id: Vec<u8>,
    /// Distance score (smaller is closer for k-NN; ignored
    /// when [`MergeOrder::DocIdAscending`] is in effect).
    pub score: f32,
}
418
/// One peer's reply to a broadcast.
///
/// `timed_out == true` is the protocol's explicit signal that
/// the per-peer deadline elapsed before the peer produced a
/// reply; the coordinator counts these toward
/// [`BroadcastResponse::peers_timed_out`] and tags the result
/// as partial.
///
/// A probe that returns an error is reported by [`broadcast`]
/// as an empty reply with `timed_out == false`.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct PeerReply {
    /// Per-peer top-K, already sorted by the peer.
    pub hits: Vec<HitWithScore>,
    /// True when the per-peer deadline elapsed.
    pub timed_out: bool,
}
433
/// Cluster-wide FT.SEARCH request.
///
/// Crosses the wire as the payload of a
/// [`crate::proto::dnode::DmsgType::FtSearchReq`] frame; see
/// [`super::wire`] for the codec.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BroadcastRequest {
    /// Index name (the FT.CREATE first argument).
    pub table: String,
    /// Encoded query body.
    pub query: SerializedQuery,
    /// Number of results to return. Zero makes
    /// [`merge_hits_ranked`] return an empty list.
    pub top_k: u32,
}
448
/// Cluster-wide FT.SEARCH response.
///
/// Returned by [`broadcast`]. The `partial` flag is true when
/// at least one peer timed out; the client surfaces this as a
/// `+WARNING` (today the test rig asserts on the flag rather
/// than the wire-level marker).
#[derive(Clone, Debug, Default, PartialEq)]
pub struct BroadcastResponse {
    /// Merged global top-K.
    pub hits: Vec<HitWithScore>,
    /// Number of peers whose replies were folded in (any peer
    /// that returned even an empty reply within the deadline).
    pub peers_consulted: usize,
    /// Number of peers whose per-peer deadline elapsed.
    pub peers_timed_out: usize,
    /// True when at least one peer timed out and the merged
    /// result therefore covers a strict subset of the cluster.
    /// [`broadcast`] also sets it when it is given no peers at
    /// all.
    pub partial: bool,
}
468
/// Tie-breaking and primary-sort policy applied by
/// [`merge_hits_ranked`] / [`broadcast`].
///
/// In both variants `doc_id` is compared bytewise.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MergeOrder {
    /// Smallest-score first; ties broken by `doc_id` ASC.
    /// Used by the k-NN path (smaller distance is closer).
    ScoreAscending,
    /// `doc_id` ASC, ignoring score. Used by the text and
    /// regex paths, where peers return matches without a
    /// score and a deterministic ordering is enough for the
    /// client.
    DocIdAscending,
}
482
/// Async per-peer probe callback.
///
/// Production wiring builds this on top of the dnode peer
/// channel (encode the [`BroadcastRequest`], send the resulting
/// [`crate::proto::dnode::DmsgType::FtSearchReq`] frame, await
/// the matching [`crate::proto::dnode::DmsgType::FtSearchRep`],
/// decode and return the hits). Tests pass an in-memory
/// callback that simulates per-peer behaviour without standing
/// up real connections.
///
/// [`broadcast`] calls the probe once per peer inside a spawned
/// task and wraps the returned future in a per-peer timeout. An
/// `Err(String)` is logged and treated as an empty reply that is
/// not flagged as timed out.
pub type AsyncPeerProbe = Arc<
    dyn Fn(
            PeerId,
            BroadcastRequest,
        )
            -> Pin<Box<dyn Future<Output = Result<Vec<HitWithScore>, String>> + Send + 'static>>
        + Send
        + Sync
        + 'static,
>;
502
503/// Pick one peer per primary token range covered by the local
504/// FT.SEARCH coordinator.
505///
506/// The walker traverses the ring once starting at token 0 and
507/// dedups by peer id, so each canonical-owner peer is visited
508/// exactly once regardless of how many vnodes it owns. Down
509/// peers (those not in `cluster.alive`) are filtered out so the
510/// caller never blocks waiting for a peer the failure detector
511/// already gave up on.
512///
513/// The returned vector is in walk order, which is deterministic
514/// for a given ring + liveness snapshot. Callers that want a
515/// specific ordering for stability tests can sort the returned
516/// vector.
517///
518/// # Examples
519///
520/// ```
521/// use std::collections::HashSet;
522/// use dynomite::cluster::apl::{ClusterState, RingPoint};
523/// use dynomite::vector::query_fsm::select_primary_peers;
524/// let cs = ClusterState::new(
525///     vec![
526///         RingPoint::new(100, 0),
527///         RingPoint::new(200, 1),
528///         RingPoint::new(300, 2),
529///     ],
530///     [0u32, 1, 2].into_iter().collect::<HashSet<_>>(),
531/// );
532/// assert_eq!(select_primary_peers(&cs).len(), 3);
533/// ```
534#[must_use]
535pub fn select_primary_peers(cluster: &ClusterState) -> Vec<PeerId> {
536    let len = cluster.ring().len();
537    if len == 0 {
538        return Vec::new();
539    }
540    walk_n_successors(cluster, 0, len)
541        .into_iter()
542        .filter(|(_, pid)| cluster.is_alive(*pid))
543        .map(|(_, pid)| pid)
544        .collect()
545}
546
547/// Default per-peer fanout deadline.
548#[must_use]
549pub const fn default_per_peer_deadline() -> Duration {
550    Duration::from_millis(DEFAULT_PER_PEER_DEADLINE_MS)
551}
552
553/// Merge per-peer hit lists into a global top-K ordered by
554/// `order`.
555///
556/// Each per-peer list is assumed to be sorted by the peer in
557/// the same order; the merge re-sorts the union and keeps the
558/// first `top_k` entries after deduplicating by `doc_id`. For
559/// [`MergeOrder::ScoreAscending`] duplicate doc ids keep the
560/// smallest score; for [`MergeOrder::DocIdAscending`] duplicate
561/// doc ids are simply elided.
562///
563/// `top_k` of zero returns an empty vector. Empty per-peer
564/// lists contribute nothing.
565///
566/// # Examples
567///
568/// ```
569/// use dynomite::vector::query_fsm::{
570///     merge_hits_ranked, HitWithScore, MergeOrder, PeerReply,
571/// };
572/// let p1 = PeerReply {
573///     hits: vec![HitWithScore { doc_id: b"a".to_vec(), score: 0.1 }],
574///     timed_out: false,
575/// };
576/// let p2 = PeerReply {
577///     hits: vec![HitWithScore { doc_id: b"b".to_vec(), score: 0.05 }],
578///     timed_out: false,
579/// };
580/// let merged = merge_hits_ranked(&[p1, p2], 2, MergeOrder::ScoreAscending);
581/// assert_eq!(merged[0].doc_id, b"b");
582/// assert_eq!(merged[1].doc_id, b"a");
583/// ```
584#[must_use]
585pub fn merge_hits_ranked(
586    per_peer: &[PeerReply],
587    top_k: u32,
588    order: MergeOrder,
589) -> Vec<HitWithScore> {
590    let cap = usize::try_from(top_k).unwrap_or(usize::MAX);
591    if cap == 0 {
592        return Vec::new();
593    }
594    let mut all: Vec<HitWithScore> = per_peer
595        .iter()
596        .flat_map(|reply| reply.hits.iter().cloned())
597        .collect();
598    sort_hits(&mut all, order);
599    let mut seen: HashSet<Vec<u8>> = HashSet::with_capacity(all.len().min(cap));
600    let mut out: Vec<HitWithScore> = Vec::with_capacity(cap);
601    for hit in all {
602        if seen.insert(hit.doc_id.clone()) {
603            out.push(hit);
604            if out.len() >= cap {
605                break;
606            }
607        }
608    }
609    out
610}
611
612fn sort_hits(hits: &mut [HitWithScore], order: MergeOrder) {
613    match order {
614        MergeOrder::ScoreAscending => {
615            hits.sort_by(|a, b| {
616                a.score
617                    .partial_cmp(&b.score)
618                    .unwrap_or(std::cmp::Ordering::Equal)
619                    .then_with(|| a.doc_id.cmp(&b.doc_id))
620            });
621        }
622        MergeOrder::DocIdAscending => {
623            hits.sort_by(|a, b| a.doc_id.cmp(&b.doc_id));
624        }
625    }
626}
627
628// ---- Distributed broadcast FSM ----------------------------------------
629
/// Events consumed by the distributed broadcast FSM.
#[derive(Debug)]
pub enum BroadcastEvent {
    /// One peer's reply (success, application error, or timeout).
    PeerReplied(PeerReply),
    /// Internal: every peer has reported back; transition
    /// from [`BroadcastState::Gathering`] to
    /// [`BroadcastState::Merging`]. Also posted when the overall
    /// deadline fires and missing replies have been synthesised.
    AllReceived,
    /// Internal: merge has produced the final response;
    /// transition from [`BroadcastState::Merging`] to
    /// terminal stop.
    // NOTE(review): nothing in the visible code posts this
    // event; the handler treats it exactly like `AllReceived`.
    MergeDone,
}
644
/// States of the distributed broadcast FSM.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BroadcastState {
    /// Pre-fanout: waiting for the orchestrator to start
    /// posting replies. No state timeout is armed here; the
    /// handler arms it when the first reply moves the FSM to
    /// [`BroadcastState::Gathering`].
    Init,
    /// Receiving per-peer replies. Counters live on the FSM
    /// data; the variant itself is parameter-free so it stays
    /// `Copy`.
    Gathering,
    /// Merging the per-peer hit lists into the global top-K.
    Merging,
}
658
/// Distributed broadcast coordinator.
///
/// Holds the FSM data: the request, the running list of
/// per-peer replies, the merge order, the response cell, and
/// the peer count needed to detect completion.
pub struct BroadcastCoordinator {
    // The broadcast request; only `top_k` is read by the coordinator.
    request: BroadcastRequest,
    // Number of replies that completes the gather.
    expected_peers: usize,
    // Replies received (or synthesised on timeout) so far.
    replies: Vec<PeerReply>,
    // Ranking policy handed to `merge_hits_ranked`.
    order: MergeOrder,
    // Cell shared with the caller; filled in when the merge completes.
    response: Arc<Mutex<Option<BroadcastResponse>>>,
    // State timeout armed when the FSM enters `Gathering`.
    overall_deadline: Duration,
}
672
673impl BroadcastCoordinator {
674    /// Construct a fresh coordinator.
675    #[must_use]
676    pub fn new(
677        request: BroadcastRequest,
678        expected_peers: usize,
679        order: MergeOrder,
680        overall_deadline: Duration,
681    ) -> (Self, Arc<Mutex<Option<BroadcastResponse>>>) {
682        let response = Arc::new(Mutex::new(None));
683        let coord = Self {
684            request,
685            expected_peers,
686            replies: Vec::with_capacity(expected_peers),
687            order,
688            response: Arc::clone(&response),
689            overall_deadline,
690        };
691        (coord, response)
692    }
693
694    fn finalise(&self) -> BroadcastResponse {
695        let timed_out = self.replies.iter().filter(|r| r.timed_out).count();
696        let consulted = self.replies.len();
697        let merged = merge_hits_ranked(&self.replies, self.request.top_k, self.order);
698        BroadcastResponse {
699            hits: merged,
700            peers_consulted: consulted,
701            peers_timed_out: timed_out,
702            partial: timed_out > 0 || consulted < self.expected_peers,
703        }
704    }
705}
706
impl FsmHandler for BroadcastCoordinator {
    type State = BroadcastState;
    type Event = BroadcastEvent;
    type Reply = ();
    type Stop = String;

    /// The FSM starts in [`BroadcastState::Init`].
    fn initial(&self) -> Self::State {
        BroadcastState::Init
    }

    /// Fold one event into the broadcast.
    ///
    /// A reply received in `Init` or `Gathering` is recorded.
    /// Once the expected number of replies is in, the FSM moves
    /// to `Merging`; otherwise the first reply moves
    /// `Init -> Gathering` and arms the overall deadline, and
    /// later replies leave the state unchanged. In `Merging`,
    /// `AllReceived` or `MergeDone` publishes the response and
    /// stops the FSM.
    fn handle(
        &mut self,
        state: Self::State,
        _event_type: EventType,
        event: Self::Event,
    ) -> Transition<Self> {
        match (state, event) {
            (
                BroadcastState::Init | BroadcastState::Gathering,
                BroadcastEvent::PeerReplied(reply),
            ) => {
                self.replies.push(reply);
                let everyone_reported = self.replies.len() >= self.expected_peers;
                match (everyone_reported, state) {
                    (true, _) => Transition::Next(
                        BroadcastState::Merging,
                        vec![
                            Action::cancel_state_timeout(),
                            Action::post_internal(BroadcastEvent::AllReceived),
                        ],
                    ),
                    (false, BroadcastState::Init) => Transition::Next(
                        BroadcastState::Gathering,
                        vec![Action::set_state_timeout(self.overall_deadline)],
                    ),
                    (false, _) => Transition::Keep(Vec::new()),
                }
            }
            (BroadcastState::Merging, BroadcastEvent::AllReceived | BroadcastEvent::MergeDone) => {
                let outcome = self.finalise();
                *self.response.lock() = Some(outcome);
                Transition::Stop(String::from("broadcast complete"))
            }
            // Everything else is harmless: a late PeerReplied
            // after merging began, or a duplicate AllReceived
            // from a racing internal event. Drop it quietly
            // instead of panicking.
            _ => Transition::Keep(Vec::new()),
        }
    }

    /// Overall-deadline handling: while still waiting for
    /// replies, pad the reply list with timed-out placeholders
    /// for every missing peer and move on to merging. In any
    /// other state the timeout is ignored.
    fn on_timeout(&mut self, state: Self::State, _kind: gen_fsm::TimeoutKind) -> Transition<Self> {
        let still_waiting = matches!(state, BroadcastState::Gathering | BroadcastState::Init);
        if !still_waiting {
            return Transition::Keep(Vec::new());
        }
        // One placeholder per peer that never answered, so the
        // merge knows the broadcast is partial.
        let missing = self.expected_peers.saturating_sub(self.replies.len());
        self.replies.extend((0..missing).map(|_| PeerReply {
            hits: Vec::new(),
            timed_out: true,
        }));
        Transition::Next(
            BroadcastState::Merging,
            vec![Action::post_internal(BroadcastEvent::AllReceived)],
        )
    }
}
776
777/// Drive the distributed FT.SEARCH coordinator to completion.
778///
779/// `peers` is the list of peer ids the request will be
780/// broadcast to; build it via [`select_primary_peers`] from a
781/// [`crate::cluster::apl::ClusterState`] in production. `probe`
782/// is invoked once per peer and is responsible for actually
783/// running the per-peer search (in production, by serialising
784/// the request via [`super::wire::encode_request`] and writing
785/// it down the dnode peer channel).
786///
787/// Each per-peer probe is wrapped in a
788/// [`tokio::time::timeout`] of `per_peer_deadline`. A timed-out
789/// peer contributes an empty [`PeerReply`] flagged
790/// `timed_out = true`; it does not abort the broadcast.
791///
792/// `order` selects the merge ranking: pass
793/// [`MergeOrder::ScoreAscending`] for the k-NN path (smaller
794/// distance is closer) or [`MergeOrder::DocIdAscending`] for
795/// the trigram and regex text paths.
796///
797/// Returns a [`BroadcastResponse`] whose `partial` flag is
798/// `true` when at least one peer timed out (or when no peers
799/// were supplied at all).
800///
801/// # Errors
802///
803/// Surfaces any [`gen_fsm::DriverError`] from the underlying
804/// FSM driver.
805pub async fn broadcast(
806    request: BroadcastRequest,
807    peers: Vec<PeerId>,
808    probe: AsyncPeerProbe,
809    per_peer_deadline: Duration,
810    order: MergeOrder,
811) -> Result<BroadcastResponse, gen_fsm::DriverError> {
812    if peers.is_empty() {
813        return Ok(BroadcastResponse {
814            hits: Vec::new(),
815            peers_consulted: 0,
816            peers_timed_out: 0,
817            partial: true,
818        });
819    }
820    // Overall deadline: a generous safety net above the
821    // per-peer deadline. The coordinator drives termination
822    // off the per-peer fan-in; this only fires if a probe task
823    // panics or the runtime stalls before the per-peer timeout
824    // can elapse.
825    let overall = per_peer_deadline
826        .saturating_mul(2)
827        .saturating_add(Duration::from_secs(1));
828    let n = peers.len();
829    let (handler, response) = BroadcastCoordinator::new(request.clone(), n, order, overall);
830    let driver: FsmDriver<BroadcastCoordinator> = FsmDriver::start(handler);
831    let (reply_tx, mut reply_rx) = mpsc::channel::<PeerReply>(n);
832    for peer in peers {
833        let probe = Arc::clone(&probe);
834        let req = request.clone();
835        let tx = reply_tx.clone();
836        tokio::spawn(async move {
837            let fut = probe(peer, req);
838            let reply = match tokio::time::timeout(per_peer_deadline, fut).await {
839                Ok(Ok(hits)) => PeerReply {
840                    hits,
841                    timed_out: false,
842                },
843                Ok(Err(err)) => {
844                    tracing::warn!(peer=peer, error=%err, "FT.SEARCH peer probe failed");
845                    PeerReply {
846                        hits: Vec::new(),
847                        timed_out: false,
848                    }
849                }
850                Err(_) => {
851                    tracing::warn!(
852                        peer = peer,
853                        "FT.SEARCH peer probe timed out (per-peer deadline elapsed)"
854                    );
855                    PeerReply {
856                        hits: Vec::new(),
857                        timed_out: true,
858                    }
859                }
860            };
861            let _ = tx.send(reply).await;
862        });
863    }
864    drop(reply_tx);
865    let driver_for_pump = driver.clone();
866    let pump = tokio::spawn(async move {
867        while let Some(reply) = reply_rx.recv().await {
868            if driver_for_pump
869                .cast_checked(BroadcastEvent::PeerReplied(reply))
870                .await
871                .is_err()
872            {
873                break;
874            }
875        }
876    });
877    let _ = driver.join().await?;
878    let _ = pump.await;
879    let final_resp = response
880        .lock()
881        .clone()
882        .unwrap_or_else(|| BroadcastResponse {
883            hits: Vec::new(),
884            peers_consulted: 0,
885            peers_timed_out: n,
886            partial: true,
887        });
888    Ok(final_resp)
889}
890
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use dynvec::SearchResult;

    use super::*;
    use crate::cluster::apl::{ClusterState, RingPoint};

    /// Deadline handed to every `run` invocation below.
    const RUN_DEADLINE: Duration = Duration::from_secs(1);

    /// Baseline request: table `t`, a four-element zero vector, `k = 3`.
    fn req() -> SearchRequest {
        SearchRequest {
            table: "t".to_string(),
            vector: vec![0.0; 4],
            k: 3,
            ef: None,
        }
    }

    /// Converts borrowed peer names into the owned list passed to `run`.
    fn peer_names(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_string()).collect()
    }

    /// Two peers with disjoint ids: the three lowest scores come
    /// back, ordered by score.
    #[tokio::test]
    async fn merges_hits_from_multiple_peers() {
        let from_p1 = vec![
            SearchResult { id: 1, score: 0.1 },
            SearchResult { id: 2, score: 0.5 },
        ];
        let from_p2 = vec![
            SearchResult { id: 3, score: 0.2 },
            SearchResult { id: 4, score: 0.6 },
        ];
        let probe: PeerProbe = Arc::new(move |peer, _r| match peer {
            "p1" => Ok(from_p1.clone()),
            "p2" => Ok(from_p2.clone()),
            _ => Err("unknown peer".to_string()),
        });
        let resp = run(req(), peer_names(&["p1", "p2"]), probe, RUN_DEADLINE)
            .await
            .unwrap();
        assert_eq!(resp.peers_consulted, 2);
        assert_eq!(resp.hits.len(), 3);
        let ranked: Vec<_> = resp.hits.iter().map(|hit| hit.id).collect();
        assert_eq!(ranked, vec![1, 3, 2]);
    }

    /// A peer whose probe errors is skipped; the healthy peer's
    /// hit is still returned.
    #[tokio::test]
    async fn missing_peers_are_tolerated() {
        let probe: PeerProbe = Arc::new(|peer, _r| match peer {
            "good" => Ok(vec![SearchResult { id: 1, score: 0.1 }]),
            _ => Err("dead".to_string()),
        });
        let resp = run(req(), peer_names(&["good", "bad"]), probe, RUN_DEADLINE)
            .await
            .unwrap();
        assert_eq!(resp.peers_consulted, 1);
        assert_eq!(resp.hits.len(), 1);
        assert_eq!(resp.hits[0].id, 1);
    }

    /// The same id reported by two peers appears once, carrying
    /// the smaller score.
    #[tokio::test]
    async fn duplicate_ids_collapsed() {
        let probe: PeerProbe = Arc::new(|peer, _r| match peer {
            "p1" => Ok(vec![SearchResult { id: 1, score: 0.10 }]),
            "p2" => Ok(vec![SearchResult { id: 1, score: 0.05 }]),
            _ => Err("unknown".to_string()),
        });
        let request = SearchRequest {
            vector: Vec::new(),
            k: 2,
            ..req()
        };
        let resp = run(request, peer_names(&["p1", "p2"]), probe, RUN_DEADLINE)
            .await
            .unwrap();
        assert_eq!(resp.hits.len(), 1);
        assert!((resp.hits[0].score - 0.05).abs() < 1e-6);
    }

    // ---- Distributed broadcast FSM tests --------------------

    /// k-NN broadcast request against index `idx` with the given `top_k`.
    fn knn_request(top_k: u32) -> BroadcastRequest {
        let query = SerializedQuery::Knn {
            vector_field: "v".into(),
            vector_bytes: vec![0u8; 16],
            ef: None,
        };
        BroadcastRequest {
            table: "idx".into(),
            query,
            top_k,
        }
    }

    /// Probe that answers each peer from a canned table; peers
    /// missing from the table get an empty hit list.
    fn fixed_probe(per_peer: HashMap<PeerId, Vec<HitWithScore>>) -> AsyncPeerProbe {
        Arc::new(move |peer, _req| {
            let canned = per_peer.get(&peer).cloned().unwrap_or_default();
            Box::pin(async move { Ok(canned) })
        })
    }

    /// Wraps hits in a reply that did not time out.
    fn reply(hits: Vec<HitWithScore>) -> PeerReply {
        PeerReply {
            hits,
            timed_out: false,
        }
    }

    /// Runs `broadcast` with a k-NN request and score-ascending
    /// merge order, unwrapping the driver result.
    async fn knn_broadcast(
        top_k: u32,
        peers: Vec<PeerId>,
        probe: AsyncPeerProbe,
        per_peer_deadline: Duration,
    ) -> BroadcastResponse {
        broadcast(
            knn_request(top_k),
            peers,
            probe,
            per_peer_deadline,
            MergeOrder::ScoreAscending,
        )
        .await
        .unwrap()
    }

    /// Score-ascending merge keeps the three smallest scores
    /// across both replies.
    #[tokio::test]
    async fn merge_score_ascending_picks_smallest_scores() {
        let first = reply(vec![
            HitWithScore {
                doc_id: b"a".to_vec(),
                score: 0.1,
            },
            HitWithScore {
                doc_id: b"b".to_vec(),
                score: 0.5,
            },
        ]);
        let second = reply(vec![
            HitWithScore {
                doc_id: b"c".to_vec(),
                score: 0.05,
            },
            HitWithScore {
                doc_id: b"d".to_vec(),
                score: 0.6,
            },
        ]);
        let merged = merge_hits_ranked(&[first, second], 3, MergeOrder::ScoreAscending);
        assert_eq!(merged.len(), 3);
        assert_eq!(merged[0].doc_id, b"c");
        assert_eq!(merged[1].doc_id, b"a");
        assert_eq!(merged[2].doc_id, b"b");
    }

    /// Doc-id-ascending merge orders hits by key bytes.
    #[tokio::test]
    async fn merge_doc_id_ascending_orders_by_key() {
        let first = reply(vec![
            HitWithScore {
                doc_id: b"key:9".to_vec(),
                score: 0.0,
            },
            HitWithScore {
                doc_id: b"key:1".to_vec(),
                score: 0.0,
            },
        ]);
        let second = reply(vec![HitWithScore {
            doc_id: b"key:5".to_vec(),
            score: 0.0,
        }]);
        let merged = merge_hits_ranked(&[first, second], 5, MergeOrder::DocIdAscending);
        let keys = merged.iter().map(|h| h.doc_id.clone()).collect::<Vec<_>>();
        assert_eq!(
            keys,
            vec![b"key:1".to_vec(), b"key:5".to_vec(), b"key:9".to_vec()],
        );
    }

    /// A doc id present in two replies is merged into a single
    /// hit with the smaller score.
    #[tokio::test]
    async fn merge_dedups_doc_ids_in_score_order() {
        let first = reply(vec![HitWithScore {
            doc_id: b"a".to_vec(),
            score: 0.10,
        }]);
        let second = reply(vec![HitWithScore {
            doc_id: b"a".to_vec(),
            score: 0.05,
        }]);
        let merged = merge_hits_ranked(&[first, second], 5, MergeOrder::ScoreAscending);
        assert_eq!(merged.len(), 1);
        assert!((merged[0].score - 0.05).abs() < 1e-6);
    }

    /// `top_k = 0` yields no hits even when replies carry some.
    #[tokio::test]
    async fn merge_top_k_zero_returns_empty() {
        let only = reply(vec![HitWithScore {
            doc_id: b"a".to_vec(),
            score: 0.1,
        }]);
        let merged = merge_hits_ranked(&[only], 0, MergeOrder::ScoreAscending);
        assert!(merged.is_empty());
    }

    /// An empty peer list short-circuits to an empty, partial response.
    #[tokio::test]
    async fn broadcast_with_no_peers_returns_partial_empty() {
        let resp = knn_broadcast(
            5,
            Vec::new(),
            fixed_probe(HashMap::new()),
            Duration::from_millis(50),
        )
        .await;
        assert!(resp.hits.is_empty());
        assert_eq!(resp.peers_consulted, 0);
        assert!(resp.partial);
    }

    /// A single responsive peer yields a complete, non-partial answer.
    #[tokio::test]
    async fn broadcast_one_peer_returns_local_top_k() {
        let canned: HashMap<PeerId, Vec<HitWithScore>> = HashMap::from([(
            7,
            vec![
                HitWithScore {
                    doc_id: b"a".to_vec(),
                    score: 0.10,
                },
                HitWithScore {
                    doc_id: b"b".to_vec(),
                    score: 0.30,
                },
            ],
        )]);
        let resp = knn_broadcast(2, vec![7], fixed_probe(canned), Duration::from_millis(200)).await;
        assert_eq!(resp.peers_consulted, 1);
        assert_eq!(resp.peers_timed_out, 0);
        assert!(!resp.partial);
        assert_eq!(resp.hits.len(), 2);
        assert_eq!(resp.hits[0].doc_id, b"a");
    }

    /// Hits from two peers are merged into one score-ordered top-3.
    #[tokio::test]
    async fn broadcast_two_peers_merges() {
        let canned: HashMap<PeerId, Vec<HitWithScore>> = HashMap::from([
            (
                1,
                vec![
                    HitWithScore {
                        doc_id: b"a".to_vec(),
                        score: 0.10,
                    },
                    HitWithScore {
                        doc_id: b"b".to_vec(),
                        score: 0.40,
                    },
                ],
            ),
            (
                2,
                vec![
                    HitWithScore {
                        doc_id: b"c".to_vec(),
                        score: 0.05,
                    },
                    HitWithScore {
                        doc_id: b"d".to_vec(),
                        score: 0.50,
                    },
                ],
            ),
        ]);
        let resp = knn_broadcast(
            3,
            vec![1, 2],
            fixed_probe(canned),
            Duration::from_millis(200),
        )
        .await;
        assert_eq!(resp.peers_consulted, 2);
        assert_eq!(resp.hits.len(), 3);
        assert_eq!(resp.hits[0].doc_id, b"c");
        assert_eq!(resp.hits[1].doc_id, b"a");
        assert_eq!(resp.hits[2].doc_id, b"b");
    }

    /// Peer 9 sleeps past the per-peer deadline; the response
    /// keeps peer 1's hit and is flagged partial.
    #[tokio::test]
    async fn broadcast_one_peer_timeout_marks_partial() {
        let probe: AsyncPeerProbe = Arc::new(move |peer, _req| {
            Box::pin(async move {
                if peer != 9 {
                    Ok(vec![HitWithScore {
                        doc_id: b"x".to_vec(),
                        score: 0.10,
                    }])
                } else {
                    tokio::time::sleep(Duration::from_millis(500)).await;
                    Ok(Vec::new())
                }
            })
        });
        let resp = knn_broadcast(3, vec![1, 9], probe, Duration::from_millis(50)).await;
        assert_eq!(resp.peers_consulted, 2);
        assert_eq!(resp.peers_timed_out, 1);
        assert!(resp.partial);
        assert_eq!(resp.hits.len(), 1);
        assert_eq!(resp.hits[0].doc_id, b"x");
    }

    /// Every peer sleeps past the deadline: all are counted as
    /// timed out and the response is empty and partial.
    #[tokio::test]
    async fn broadcast_all_peers_timeout_returns_empty_partial() {
        let probe: AsyncPeerProbe = Arc::new(|_peer, _req| {
            Box::pin(async move {
                tokio::time::sleep(Duration::from_millis(500)).await;
                Ok(Vec::new())
            })
        });
        let resp = knn_broadcast(3, vec![1, 2, 3], probe, Duration::from_millis(40)).await;
        assert_eq!(resp.peers_consulted, 3);
        assert_eq!(resp.peers_timed_out, 3);
        assert!(resp.partial);
        assert!(resp.hits.is_empty());
    }

    /// Three peers, one ring entry each, all alive: each is
    /// selected exactly once.
    #[tokio::test]
    async fn select_primary_peers_returns_one_per_distinct_alive_peer() {
        let ring = vec![
            RingPoint::new(100, 0),
            RingPoint::new(200, 1),
            RingPoint::new(300, 2),
        ];
        let alive: HashSet<u32> = HashSet::from([0, 1, 2]);
        let cs = ClusterState::new(ring, alive);
        let mut primaries = select_primary_peers(&cs);
        primaries.sort_unstable();
        assert_eq!(primaries, vec![0, 1, 2]);
    }

    /// Peer 1 is absent from the alive set and must not be selected.
    #[tokio::test]
    async fn select_primary_peers_filters_dead_peers() {
        let ring = vec![
            RingPoint::new(100, 0),
            RingPoint::new(200, 1),
            RingPoint::new(300, 2),
        ];
        let alive: HashSet<u32> = HashSet::from([0, 2]);
        let cs = ClusterState::new(ring, alive);
        let mut primaries = select_primary_peers(&cs);
        primaries.sort_unstable();
        assert_eq!(primaries, vec![0, 2]);
    }

    /// Peer 0 owns two ring entries (multi-vnode) yet is
    /// returned only once.
    #[tokio::test]
    async fn select_primary_peers_dedups_multi_vnode_peers() {
        let ring = vec![
            RingPoint::new(100, 0),
            RingPoint::new(200, 0),
            RingPoint::new(300, 1),
        ];
        let alive: HashSet<u32> = HashSet::from([0, 1]);
        let cs = ClusterState::new(ring, alive);
        let mut primaries = select_primary_peers(&cs);
        primaries.sort_unstable();
        assert_eq!(primaries, vec![0, 1]);
    }
}
1296}