Skip to main content

dynomite/vector/
wire.rs

1//! On-the-wire codec for cluster-coordinated FT.SEARCH.
2//!
3//! This module defines the binary serialisation format the
4//! coordinator uses when broadcasting an FT.SEARCH to remote
5//! peers via the [`crate::proto::dnode::DmsgType::FtSearchReq`]
6//! and [`crate::proto::dnode::DmsgType::FtSearchRep`] DNODE
7//! frames.
8//!
9//! # Format choice
10//!
11//! The codec is a small, hand-rolled length-prefixed layout
12//! that uses only the standard library, mirroring the
13//! [`crate::proto::dnode::Handshake`] approach. Pulling in a
14//! heavier serde codec was rejected because:
15//!
16//! * the message shapes are tiny and stable;
17//! * the FT.SEARCH path is hot, so allocation and parse cost
18//!   matter;
19//! * keeping the codec in this module keeps the cluster-FT
20//!   surface honest: any new field shows up here and is
21//!   covered by the round-trip tests below.
22//!
23//! All multi-byte integers are little-endian. Lengths are
24//! `u32` so individual fields are bounded at 4 GiB which is
25//! well above any realistic vector / pattern payload.
26//!
27//! # Wire layout
28//!
29//! ## Request (`FtSearchReq`)
30//!
31//! ```text
32//! magic(4)   = "FTQ1"
33//! flags(2)   = 0
34//! top_k(4)   = u32 (LE)
35//! table_len  = u32 (LE)
36//! table      = utf-8 bytes
37//! query_tag  = u8  (0=KNN, 1=Text, 2=Regex)
38//! query body = ...   (depends on tag, see below)
39//! ```
40//!
41//! ## Reply (`FtSearchRep`)
42//!
43//! ```text
44//! magic(4)        = "FTR1"
45//! flags(2)        = 0
46//! timed_out(1)    = 0|1
47//! hit_count(4)    = u32 (LE)
48//! repeat hit_count times:
49//!     doc_id_len  = u32 (LE)
50//!     doc_id      = bytes
51//!     score       = f32 (LE)
52//! ```
53//!
54//! Tag bodies:
55//!
56//! ```text
57//! KNN:    field_len(4) field_utf8 bytes_len(4) vector_bytes
58//!         ef_present(1) [ef(4)]
59//! Text:   field_len(4) field_utf8 query_len(4) query_bytes
60//! Regex:  field_len(4) field_utf8 pattern_len(4) pattern_utf8
61//!         max_errors(2)
62//! ```
63
64use std::convert::TryFrom;
65
66use thiserror::Error;
67
68use super::query_fsm::{BroadcastRequest, HitWithScore, PeerReply, SerializedQuery};
69
/// Four-byte header that every encoded [`BroadcastRequest`]
/// payload starts with (ASCII `FTQ1`).
pub const REQ_MAGIC: [u8; 4] = [b'F', b'T', b'Q', b'1'];

/// Four-byte header that every encoded [`PeerReply`] payload
/// starts with (ASCII `FTR1`).
pub const REP_MAGIC: [u8; 4] = [b'F', b'T', b'R', b'1'];

// Query-body discriminants, written as the single byte that
// follows the length-prefixed table name in a request.
const TAG_KNN: u8 = 0x00;
const TAG_TEXT: u8 = 0x01;
const TAG_REGEX: u8 = 0x02;
81
82/// Errors raised by the cluster-FT codec.
83#[derive(Debug, Error, PartialEq, Eq)]
84#[non_exhaustive]
85pub enum CodecError {
86    /// Payload was shorter than required by the layout.
87    #[error("FT search payload truncated")]
88    Truncated,
89    /// Payload header magic did not match.
90    #[error("FT search payload bad magic")]
91    BadMagic,
92    /// Payload reserved-flags field was non-zero.
93    #[error("FT search payload bad flags")]
94    BadFlags,
95    /// Encoded length exceeds the remaining slice.
96    #[error("FT search field length out of range")]
97    LengthOverflow,
98    /// Embedded UTF-8 string did not parse.
99    #[error("FT search field not utf-8")]
100    BadUtf8,
101    /// Query body tag byte is not one of the known variants.
102    #[error("FT search unknown query tag {0}")]
103    BadTag(u8),
104}
105
106/// Encode a [`BroadcastRequest`] to a binary payload suitable
107/// for the [`crate::proto::dnode::DmsgType::FtSearchReq`]
108/// DNODE frame.
109///
110/// # Examples
111///
112/// ```
113/// use dynomite::vector::query_fsm::{BroadcastRequest, SerializedQuery};
114/// use dynomite::vector::wire::{decode_request, encode_request};
115///
116/// let req = BroadcastRequest {
117///     table: "idx".into(),
118///     query: SerializedQuery::Text {
119///         field: "body".into(),
120///         query: b"foo".to_vec(),
121///     },
122///     top_k: 10,
123/// };
124/// let bytes = encode_request(&req);
125/// let back = decode_request(&bytes).unwrap();
126/// assert_eq!(req, back);
127/// ```
128#[must_use]
129pub fn encode_request(req: &BroadcastRequest) -> Vec<u8> {
130    let mut out = Vec::with_capacity(64);
131    out.extend_from_slice(&REQ_MAGIC);
132    out.extend_from_slice(&0u16.to_le_bytes());
133    out.extend_from_slice(&req.top_k.to_le_bytes());
134    write_bytes(&mut out, req.table.as_bytes());
135    match &req.query {
136        SerializedQuery::Knn {
137            vector_field,
138            vector_bytes,
139            ef,
140        } => {
141            out.push(TAG_KNN);
142            write_bytes(&mut out, vector_field.as_bytes());
143            write_bytes(&mut out, vector_bytes);
144            match ef {
145                Some(value) => {
146                    out.push(1);
147                    out.extend_from_slice(&value.to_le_bytes());
148                }
149                None => out.push(0),
150            }
151        }
152        SerializedQuery::Text { field, query } => {
153            out.push(TAG_TEXT);
154            write_bytes(&mut out, field.as_bytes());
155            write_bytes(&mut out, query);
156        }
157        SerializedQuery::Regex {
158            field,
159            pattern,
160            max_errors,
161        } => {
162            out.push(TAG_REGEX);
163            write_bytes(&mut out, field.as_bytes());
164            write_bytes(&mut out, pattern.as_bytes());
165            out.extend_from_slice(&max_errors.to_le_bytes());
166        }
167    }
168    out
169}
170
171/// Decode a [`BroadcastRequest`] previously produced by
172/// [`encode_request`].
173///
174/// # Errors
175///
176/// Returns [`CodecError`] when the payload is truncated, the
177/// magic header is wrong, or any embedded string is not valid
178/// UTF-8.
179pub fn decode_request(bytes: &[u8]) -> Result<BroadcastRequest, CodecError> {
180    let mut cursor = Cursor::new(bytes);
181    let magic = cursor.take_array::<4>()?;
182    if magic != REQ_MAGIC {
183        return Err(CodecError::BadMagic);
184    }
185    let flags = cursor.take_u16()?;
186    if flags != 0 {
187        return Err(CodecError::BadFlags);
188    }
189    let top_k = cursor.take_u32()?;
190    let table_bytes = cursor.take_bytes()?.to_vec();
191    let table = String::from_utf8(table_bytes).map_err(|_| CodecError::BadUtf8)?;
192    let tag = cursor.take_u8()?;
193    let query = match tag {
194        TAG_KNN => {
195            let field_bytes = cursor.take_bytes()?.to_vec();
196            let vector_field = String::from_utf8(field_bytes).map_err(|_| CodecError::BadUtf8)?;
197            let vector_bytes = cursor.take_bytes()?.to_vec();
198            let ef_present = cursor.take_u8()?;
199            let ef = match ef_present {
200                0 => None,
201                1 => Some(cursor.take_u32()?),
202                _ => return Err(CodecError::BadFlags),
203            };
204            SerializedQuery::Knn {
205                vector_field,
206                vector_bytes,
207                ef,
208            }
209        }
210        TAG_TEXT => {
211            let field_bytes = cursor.take_bytes()?.to_vec();
212            let field = String::from_utf8(field_bytes).map_err(|_| CodecError::BadUtf8)?;
213            let query = cursor.take_bytes()?.to_vec();
214            SerializedQuery::Text { field, query }
215        }
216        TAG_REGEX => {
217            let field_bytes = cursor.take_bytes()?.to_vec();
218            let field = String::from_utf8(field_bytes).map_err(|_| CodecError::BadUtf8)?;
219            let pattern_bytes = cursor.take_bytes()?.to_vec();
220            let pattern = String::from_utf8(pattern_bytes).map_err(|_| CodecError::BadUtf8)?;
221            let max_errors = cursor.take_u16()?;
222            SerializedQuery::Regex {
223                field,
224                pattern,
225                max_errors,
226            }
227        }
228        other => return Err(CodecError::BadTag(other)),
229    };
230    Ok(BroadcastRequest {
231        table,
232        query,
233        top_k,
234    })
235}
236
237/// Encode a [`PeerReply`] (one peer's per-peer top-K) for the
238/// [`crate::proto::dnode::DmsgType::FtSearchRep`] DNODE frame.
239///
240/// # Examples
241///
242/// ```
243/// use dynomite::vector::query_fsm::{HitWithScore, PeerReply};
244/// use dynomite::vector::wire::{decode_reply, encode_reply};
245///
246/// let reply = PeerReply {
247///     hits: vec![HitWithScore {
248///         doc_id: b"key:1".to_vec(),
249///         score: 0.25,
250///     }],
251///     timed_out: false,
252/// };
253/// let bytes = encode_reply(&reply);
254/// let back = decode_reply(&bytes).unwrap();
255/// assert_eq!(reply, back);
256/// ```
257#[must_use]
258pub fn encode_reply(reply: &PeerReply) -> Vec<u8> {
259    let mut out = Vec::with_capacity(32 + reply.hits.len() * 24);
260    out.extend_from_slice(&REP_MAGIC);
261    out.extend_from_slice(&0u16.to_le_bytes());
262    out.push(u8::from(reply.timed_out));
263    let count = u32::try_from(reply.hits.len()).unwrap_or(u32::MAX);
264    out.extend_from_slice(&count.to_le_bytes());
265    let max = count as usize;
266    for hit in reply.hits.iter().take(max) {
267        write_bytes(&mut out, &hit.doc_id);
268        out.extend_from_slice(&hit.score.to_le_bytes());
269    }
270    out
271}
272
273/// Decode a [`PeerReply`] previously produced by
274/// [`encode_reply`].
275///
276/// # Errors
277///
278/// Returns [`CodecError`] when the payload is truncated or the
279/// magic header is wrong.
280pub fn decode_reply(bytes: &[u8]) -> Result<PeerReply, CodecError> {
281    let mut cursor = Cursor::new(bytes);
282    let magic = cursor.take_array::<4>()?;
283    if magic != REP_MAGIC {
284        return Err(CodecError::BadMagic);
285    }
286    let flags = cursor.take_u16()?;
287    if flags != 0 {
288        return Err(CodecError::BadFlags);
289    }
290    let timed_out_byte = cursor.take_u8()?;
291    if timed_out_byte > 1 {
292        return Err(CodecError::BadFlags);
293    }
294    let timed_out = timed_out_byte == 1;
295    let count = cursor.take_u32()?;
296    let count_usize = usize::try_from(count).map_err(|_| CodecError::LengthOverflow)?;
297    let mut hits: Vec<HitWithScore> = Vec::with_capacity(count_usize.min(64));
298    for _ in 0..count_usize {
299        let doc_id = cursor.take_bytes()?.to_vec();
300        let score = cursor.take_f32()?;
301        hits.push(HitWithScore { doc_id, score });
302    }
303    Ok(PeerReply { hits, timed_out })
304}
305
306// ---- helpers -----------------------------------------------------------
307
/// Append `bytes` to `out` as a little-endian `u32` length prefix
/// followed by the raw bytes.
///
/// Inputs longer than `u32::MAX` are clamped: the prefix saturates at
/// `u32::MAX` and only that many bytes are copied. The prefix therefore
/// always matches what was written, and encoding never panics.
fn write_bytes(out: &mut Vec<u8>, bytes: &[u8]) {
    let prefix = u32::try_from(bytes.len()).unwrap_or(u32::MAX);
    let body = bytes.get(..prefix as usize).unwrap_or(bytes);
    out.extend_from_slice(&prefix.to_le_bytes());
    out.extend_from_slice(body);
}
314
315struct Cursor<'a> {
316    buf: &'a [u8],
317    pos: usize,
318}
319
320impl<'a> Cursor<'a> {
321    fn new(buf: &'a [u8]) -> Self {
322        Self { buf, pos: 0 }
323    }
324
325    fn require(&self, want: usize) -> Result<(), CodecError> {
326        if self
327            .pos
328            .checked_add(want)
329            .is_none_or(|end| end > self.buf.len())
330        {
331            return Err(CodecError::Truncated);
332        }
333        Ok(())
334    }
335
336    fn take_array<const N: usize>(&mut self) -> Result<[u8; N], CodecError> {
337        self.require(N)?;
338        let mut out = [0u8; N];
339        out.copy_from_slice(&self.buf[self.pos..self.pos + N]);
340        self.pos += N;
341        Ok(out)
342    }
343
344    fn take_u8(&mut self) -> Result<u8, CodecError> {
345        self.require(1)?;
346        let v = self.buf[self.pos];
347        self.pos += 1;
348        Ok(v)
349    }
350
351    fn take_u16(&mut self) -> Result<u16, CodecError> {
352        let bytes = self.take_array::<2>()?;
353        Ok(u16::from_le_bytes(bytes))
354    }
355
356    fn take_u32(&mut self) -> Result<u32, CodecError> {
357        let bytes = self.take_array::<4>()?;
358        Ok(u32::from_le_bytes(bytes))
359    }
360
361    fn take_f32(&mut self) -> Result<f32, CodecError> {
362        let bytes = self.take_array::<4>()?;
363        Ok(f32::from_le_bytes(bytes))
364    }
365
366    fn take_bytes(&mut self) -> Result<&'a [u8], CodecError> {
367        let len = self.take_u32()? as usize;
368        self.require(len)?;
369        let out = &self.buf[self.pos..self.pos + len];
370        self.pos += len;
371        Ok(out)
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    fn knn_request() -> BroadcastRequest {
        BroadcastRequest {
            table: "ix".into(),
            query: SerializedQuery::Knn {
                vector_field: "v".into(),
                vector_bytes: vec![0x00, 0x01, 0x02, 0x03],
                ef: Some(64),
            },
            top_k: 5,
        }
    }

    fn empty_reply() -> PeerReply {
        PeerReply {
            hits: Vec::new(),
            timed_out: false,
        }
    }

    #[test]
    fn knn_round_trip() {
        let req = knn_request();
        let bytes = encode_request(&req);
        let back = decode_request(&bytes).unwrap();
        assert_eq!(req, back);
    }

    #[test]
    fn knn_round_trip_no_ef() {
        let mut req = knn_request();
        if let SerializedQuery::Knn { ef, .. } = &mut req.query {
            *ef = None;
        }
        let bytes = encode_request(&req);
        let back = decode_request(&bytes).unwrap();
        assert_eq!(req, back);
    }

    #[test]
    fn text_round_trip() {
        let req = BroadcastRequest {
            table: "idx".into(),
            query: SerializedQuery::Text {
                field: "body".into(),
                query: b"foo bar".to_vec(),
            },
            top_k: 3,
        };
        let bytes = encode_request(&req);
        assert_eq!(decode_request(&bytes).unwrap(), req);
    }

    #[test]
    fn regex_round_trip() {
        let req = BroadcastRequest {
            table: "idx".into(),
            query: SerializedQuery::Regex {
                field: "body".into(),
                pattern: "ab.*c".into(),
                max_errors: 2,
            },
            top_k: 7,
        };
        let bytes = encode_request(&req);
        assert_eq!(decode_request(&bytes).unwrap(), req);
    }

    #[test]
    fn reply_round_trip() {
        let reply = PeerReply {
            hits: vec![
                HitWithScore {
                    doc_id: b"a".to_vec(),
                    score: 0.10,
                },
                HitWithScore {
                    doc_id: b"longer:doc:id".to_vec(),
                    score: 0.42,
                },
            ],
            timed_out: false,
        };
        let bytes = encode_reply(&reply);
        let back = decode_reply(&bytes).unwrap();
        assert_eq!(reply, back);
    }

    #[test]
    fn reply_with_timed_out_flag() {
        let reply = PeerReply {
            hits: Vec::new(),
            timed_out: true,
        };
        let bytes = encode_reply(&reply);
        let back = decode_reply(&bytes).unwrap();
        assert!(back.timed_out);
        assert!(back.hits.is_empty());
    }

    #[test]
    fn reply_with_no_hits() {
        let reply = empty_reply();
        let bytes = encode_reply(&reply);
        let back = decode_reply(&bytes).unwrap();
        assert_eq!(reply, back);
    }

    #[test]
    fn truncated_request_rejected() {
        let req = knn_request();
        let bytes = encode_request(&req);
        for n in 0..bytes.len() {
            assert_eq!(decode_request(&bytes[..n]), Err(CodecError::Truncated));
        }
    }

    #[test]
    fn truncated_reply_rejected() {
        let reply = PeerReply {
            hits: vec![HitWithScore {
                doc_id: b"key:1".to_vec(),
                score: 0.5,
            }],
            timed_out: true,
        };
        let bytes = encode_reply(&reply);
        for n in 0..bytes.len() {
            assert_eq!(decode_reply(&bytes[..n]), Err(CodecError::Truncated));
        }
    }

    #[test]
    fn bad_magic_rejected() {
        let bytes = vec![b'X'; 32];
        assert_eq!(decode_request(&bytes).unwrap_err(), CodecError::BadMagic);
        assert_eq!(decode_reply(&bytes).unwrap_err(), CodecError::BadMagic);
    }

    #[test]
    fn bad_tag_rejected() {
        let mut bytes = encode_request(&knn_request());
        // Locate and overwrite the tag byte (right after table
        // bytes). Re-derive its index from the layout: 4 magic +
        // 2 flags + 4 top_k + 4 table_len + table_len bytes.
        let table_len_offset = 4 + 2 + 4;
        let table_len = u32::from_le_bytes(
            bytes[table_len_offset..table_len_offset + 4]
                .try_into()
                .unwrap(),
        ) as usize;
        let tag_offset = table_len_offset + 4 + table_len;
        bytes[tag_offset] = 0xff;
        assert_eq!(
            decode_request(&bytes).unwrap_err(),
            CodecError::BadTag(0xff)
        );
    }

    #[test]
    fn non_zero_flags_rejected() {
        let mut bytes = encode_request(&knn_request());
        bytes[4] = 0x01;
        assert_eq!(decode_request(&bytes).unwrap_err(), CodecError::BadFlags);
    }

    #[test]
    fn reply_non_zero_flags_rejected() {
        let mut bytes = encode_reply(&empty_reply());
        // 4 magic bytes, then the 2-byte reserved flags.
        bytes[4] = 0x01;
        assert_eq!(decode_reply(&bytes).unwrap_err(), CodecError::BadFlags);
    }

    #[test]
    fn bad_ef_presence_byte_rejected() {
        let mut bytes = encode_request(&knn_request());
        // `ef` is `Some(..)`, so the payload ends with the presence
        // byte followed by the 4-byte ef value.
        let presence = bytes.len() - 5;
        assert_eq!(bytes[presence], 1);
        bytes[presence] = 2;
        assert_eq!(decode_request(&bytes).unwrap_err(), CodecError::BadFlags);
    }

    #[test]
    fn bad_timed_out_byte_rejected() {
        let mut bytes = encode_reply(&empty_reply());
        // 4 magic + 2 flags, then the timed_out byte.
        bytes[6] = 2;
        assert_eq!(decode_reply(&bytes).unwrap_err(), CodecError::BadFlags);
    }

    #[test]
    fn non_utf8_table_rejected() {
        let mut bytes = encode_request(&knn_request());
        // 4 magic + 2 flags + 4 top_k + 4 table_len; table is "ix".
        bytes[14] = 0xff;
        assert_eq!(decode_request(&bytes).unwrap_err(), CodecError::BadUtf8);
    }

    #[test]
    fn huge_hit_count_reports_truncation() {
        let mut bytes = encode_reply(&empty_reply());
        // 4 magic + 2 flags + 1 timed_out, then the u32 hit count.
        bytes[7..11].copy_from_slice(&u32::MAX.to_le_bytes());
        assert_eq!(decode_reply(&bytes).unwrap_err(), CodecError::Truncated);
    }
}