1use std::convert::TryFrom;
65
66use thiserror::Error;
67
68use super::query_fsm::{BroadcastRequest, HitWithScore, PeerReply, SerializedQuery};
69
/// Four-byte magic (ASCII `FTQ1`) that opens every encoded request.
pub const REQ_MAGIC: [u8; 4] = [b'F', b'T', b'Q', b'1'];

/// Four-byte magic (ASCII `FTR1`) that opens every encoded reply.
pub const REP_MAGIC: [u8; 4] = [b'F', b'T', b'R', b'1'];

/// Tag byte written before the body of a `SerializedQuery::Knn` query.
const TAG_KNN: u8 = 0;
/// Tag byte written before the body of a `SerializedQuery::Text` query.
const TAG_TEXT: u8 = 1;
/// Tag byte written before the body of a `SerializedQuery::Regex` query.
const TAG_REGEX: u8 = 2;
81
82#[derive(Debug, Error, PartialEq, Eq)]
84#[non_exhaustive]
85pub enum CodecError {
86 #[error("FT search payload truncated")]
88 Truncated,
89 #[error("FT search payload bad magic")]
91 BadMagic,
92 #[error("FT search payload bad flags")]
94 BadFlags,
95 #[error("FT search field length out of range")]
97 LengthOverflow,
98 #[error("FT search field not utf-8")]
100 BadUtf8,
101 #[error("FT search unknown query tag {0}")]
103 BadTag(u8),
104}
105
106#[must_use]
129pub fn encode_request(req: &BroadcastRequest) -> Vec<u8> {
130 let mut out = Vec::with_capacity(64);
131 out.extend_from_slice(&REQ_MAGIC);
132 out.extend_from_slice(&0u16.to_le_bytes());
133 out.extend_from_slice(&req.top_k.to_le_bytes());
134 write_bytes(&mut out, req.table.as_bytes());
135 match &req.query {
136 SerializedQuery::Knn {
137 vector_field,
138 vector_bytes,
139 ef,
140 } => {
141 out.push(TAG_KNN);
142 write_bytes(&mut out, vector_field.as_bytes());
143 write_bytes(&mut out, vector_bytes);
144 match ef {
145 Some(value) => {
146 out.push(1);
147 out.extend_from_slice(&value.to_le_bytes());
148 }
149 None => out.push(0),
150 }
151 }
152 SerializedQuery::Text { field, query } => {
153 out.push(TAG_TEXT);
154 write_bytes(&mut out, field.as_bytes());
155 write_bytes(&mut out, query);
156 }
157 SerializedQuery::Regex {
158 field,
159 pattern,
160 max_errors,
161 } => {
162 out.push(TAG_REGEX);
163 write_bytes(&mut out, field.as_bytes());
164 write_bytes(&mut out, pattern.as_bytes());
165 out.extend_from_slice(&max_errors.to_le_bytes());
166 }
167 }
168 out
169}
170
171pub fn decode_request(bytes: &[u8]) -> Result<BroadcastRequest, CodecError> {
180 let mut cursor = Cursor::new(bytes);
181 let magic = cursor.take_array::<4>()?;
182 if magic != REQ_MAGIC {
183 return Err(CodecError::BadMagic);
184 }
185 let flags = cursor.take_u16()?;
186 if flags != 0 {
187 return Err(CodecError::BadFlags);
188 }
189 let top_k = cursor.take_u32()?;
190 let table_bytes = cursor.take_bytes()?.to_vec();
191 let table = String::from_utf8(table_bytes).map_err(|_| CodecError::BadUtf8)?;
192 let tag = cursor.take_u8()?;
193 let query = match tag {
194 TAG_KNN => {
195 let field_bytes = cursor.take_bytes()?.to_vec();
196 let vector_field = String::from_utf8(field_bytes).map_err(|_| CodecError::BadUtf8)?;
197 let vector_bytes = cursor.take_bytes()?.to_vec();
198 let ef_present = cursor.take_u8()?;
199 let ef = match ef_present {
200 0 => None,
201 1 => Some(cursor.take_u32()?),
202 _ => return Err(CodecError::BadFlags),
203 };
204 SerializedQuery::Knn {
205 vector_field,
206 vector_bytes,
207 ef,
208 }
209 }
210 TAG_TEXT => {
211 let field_bytes = cursor.take_bytes()?.to_vec();
212 let field = String::from_utf8(field_bytes).map_err(|_| CodecError::BadUtf8)?;
213 let query = cursor.take_bytes()?.to_vec();
214 SerializedQuery::Text { field, query }
215 }
216 TAG_REGEX => {
217 let field_bytes = cursor.take_bytes()?.to_vec();
218 let field = String::from_utf8(field_bytes).map_err(|_| CodecError::BadUtf8)?;
219 let pattern_bytes = cursor.take_bytes()?.to_vec();
220 let pattern = String::from_utf8(pattern_bytes).map_err(|_| CodecError::BadUtf8)?;
221 let max_errors = cursor.take_u16()?;
222 SerializedQuery::Regex {
223 field,
224 pattern,
225 max_errors,
226 }
227 }
228 other => return Err(CodecError::BadTag(other)),
229 };
230 Ok(BroadcastRequest {
231 table,
232 query,
233 top_k,
234 })
235}
236
237#[must_use]
258pub fn encode_reply(reply: &PeerReply) -> Vec<u8> {
259 let mut out = Vec::with_capacity(32 + reply.hits.len() * 24);
260 out.extend_from_slice(&REP_MAGIC);
261 out.extend_from_slice(&0u16.to_le_bytes());
262 out.push(u8::from(reply.timed_out));
263 let count = u32::try_from(reply.hits.len()).unwrap_or(u32::MAX);
264 out.extend_from_slice(&count.to_le_bytes());
265 let max = count as usize;
266 for hit in reply.hits.iter().take(max) {
267 write_bytes(&mut out, &hit.doc_id);
268 out.extend_from_slice(&hit.score.to_le_bytes());
269 }
270 out
271}
272
273pub fn decode_reply(bytes: &[u8]) -> Result<PeerReply, CodecError> {
281 let mut cursor = Cursor::new(bytes);
282 let magic = cursor.take_array::<4>()?;
283 if magic != REP_MAGIC {
284 return Err(CodecError::BadMagic);
285 }
286 let flags = cursor.take_u16()?;
287 if flags != 0 {
288 return Err(CodecError::BadFlags);
289 }
290 let timed_out_byte = cursor.take_u8()?;
291 if timed_out_byte > 1 {
292 return Err(CodecError::BadFlags);
293 }
294 let timed_out = timed_out_byte == 1;
295 let count = cursor.take_u32()?;
296 let count_usize = usize::try_from(count).map_err(|_| CodecError::LengthOverflow)?;
297 let mut hits: Vec<HitWithScore> = Vec::with_capacity(count_usize.min(64));
298 for _ in 0..count_usize {
299 let doc_id = cursor.take_bytes()?.to_vec();
300 let score = cursor.take_f32()?;
301 hits.push(HitWithScore { doc_id, score });
302 }
303 Ok(PeerReply { hits, timed_out })
304}
305
/// Appends `bytes` to `out` as a length-prefixed field: a little-endian `u32`
/// length followed by the payload.
///
/// If `bytes` is longer than `u32::MAX`, the length is clamped to `u32::MAX`
/// and only that many payload bytes are written, so the prefix always matches
/// the payload that follows it.
fn write_bytes(out: &mut Vec<u8>, bytes: &[u8]) {
    let declared = match u32::try_from(bytes.len()) {
        Ok(n) => n,
        Err(_) => u32::MAX,
    };
    out.extend_from_slice(&declared.to_le_bytes());

    let kept = bytes.len().min(declared as usize);
    out.extend_from_slice(&bytes[..kept]);
}
314
315struct Cursor<'a> {
316 buf: &'a [u8],
317 pos: usize,
318}
319
320impl<'a> Cursor<'a> {
321 fn new(buf: &'a [u8]) -> Self {
322 Self { buf, pos: 0 }
323 }
324
325 fn require(&self, want: usize) -> Result<(), CodecError> {
326 if self
327 .pos
328 .checked_add(want)
329 .is_none_or(|end| end > self.buf.len())
330 {
331 return Err(CodecError::Truncated);
332 }
333 Ok(())
334 }
335
336 fn take_array<const N: usize>(&mut self) -> Result<[u8; N], CodecError> {
337 self.require(N)?;
338 let mut out = [0u8; N];
339 out.copy_from_slice(&self.buf[self.pos..self.pos + N]);
340 self.pos += N;
341 Ok(out)
342 }
343
344 fn take_u8(&mut self) -> Result<u8, CodecError> {
345 self.require(1)?;
346 let v = self.buf[self.pos];
347 self.pos += 1;
348 Ok(v)
349 }
350
351 fn take_u16(&mut self) -> Result<u16, CodecError> {
352 let bytes = self.take_array::<2>()?;
353 Ok(u16::from_le_bytes(bytes))
354 }
355
356 fn take_u32(&mut self) -> Result<u32, CodecError> {
357 let bytes = self.take_array::<4>()?;
358 Ok(u32::from_le_bytes(bytes))
359 }
360
361 fn take_f32(&mut self) -> Result<f32, CodecError> {
362 let bytes = self.take_array::<4>()?;
363 Ok(f32::from_le_bytes(bytes))
364 }
365
366 fn take_bytes(&mut self) -> Result<&'a [u8], CodecError> {
367 let len = self.take_u32()? as usize;
368 self.require(len)?;
369 let out = &self.buf[self.pos..self.pos + len];
370 self.pos += len;
371 Ok(out)
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Offset of the table length prefix: magic (4) + flags (2) + top_k (4).
    const TABLE_LEN_OFFSET: usize = 4 + 2 + 4;

    /// Builds the KNN query used by the request fixtures.
    fn knn_query(ef: Option<u32>) -> SerializedQuery {
        SerializedQuery::Knn {
            vector_field: String::from("v"),
            vector_bytes: vec![0x00, 0x01, 0x02, 0x03],
            ef,
        }
    }

    /// Canonical KNN request shared by several tests.
    fn knn_request() -> BroadcastRequest {
        BroadcastRequest {
            table: String::from("ix"),
            query: knn_query(Some(64)),
            top_k: 5,
        }
    }

    /// Encodes `req`, decodes the bytes again and checks nothing changed.
    fn assert_request_round_trips(req: &BroadcastRequest) {
        let wire = encode_request(req);
        let decoded = decode_request(&wire).unwrap();
        assert_eq!(req, &decoded);
    }

    /// Encodes `reply` and returns whatever decoding the bytes yields.
    fn reply_after_round_trip(reply: &PeerReply) -> PeerReply {
        let wire = encode_reply(reply);
        decode_reply(&wire).unwrap()
    }

    #[test]
    fn knn_round_trip() {
        assert_request_round_trips(&knn_request());
    }

    #[test]
    fn knn_round_trip_no_ef() {
        let req = BroadcastRequest {
            query: knn_query(None),
            ..knn_request()
        };
        assert_request_round_trips(&req);
    }

    #[test]
    fn text_round_trip() {
        let query = SerializedQuery::Text {
            field: String::from("body"),
            query: b"foo bar".to_vec(),
        };
        assert_request_round_trips(&BroadcastRequest {
            table: String::from("idx"),
            query,
            top_k: 3,
        });
    }

    #[test]
    fn regex_round_trip() {
        let query = SerializedQuery::Regex {
            field: String::from("body"),
            pattern: String::from("ab.*c"),
            max_errors: 2,
        };
        assert_request_round_trips(&BroadcastRequest {
            table: String::from("idx"),
            query,
            top_k: 7,
        });
    }

    #[test]
    fn reply_round_trip() {
        let short = HitWithScore {
            doc_id: b"a".to_vec(),
            score: 0.10,
        };
        let long = HitWithScore {
            doc_id: b"longer:doc:id".to_vec(),
            score: 0.42,
        };
        let reply = PeerReply {
            hits: vec![short, long],
            timed_out: false,
        };
        assert_eq!(reply, reply_after_round_trip(&reply));
    }

    #[test]
    fn reply_with_timed_out_flag() {
        let decoded = reply_after_round_trip(&PeerReply {
            hits: Vec::new(),
            timed_out: true,
        });
        assert!(decoded.timed_out);
        assert!(decoded.hits.is_empty());
    }

    #[test]
    fn reply_with_no_hits() {
        let reply = PeerReply {
            hits: Vec::new(),
            timed_out: false,
        };
        assert_eq!(reply, reply_after_round_trip(&reply));
    }

    #[test]
    fn truncated_request_rejected() {
        let wire = encode_request(&knn_request());
        // Every strict prefix of a valid request must fail as truncated.
        for cut in 0..wire.len() {
            let prefix = &wire[..cut];
            assert_eq!(decode_request(prefix), Err(CodecError::Truncated));
        }
    }

    #[test]
    fn bad_magic_rejected() {
        let junk = [b'X'; 32];
        assert_eq!(decode_request(&junk).unwrap_err(), CodecError::BadMagic);
        assert_eq!(decode_reply(&junk).unwrap_err(), CodecError::BadMagic);
    }

    #[test]
    fn bad_tag_rejected() {
        let mut wire = encode_request(&knn_request());

        // The tag byte sits right after the length-prefixed table name.
        let mut len_le = [0u8; 4];
        len_le.copy_from_slice(&wire[TABLE_LEN_OFFSET..TABLE_LEN_OFFSET + 4]);
        let table_len = u32::from_le_bytes(len_le) as usize;
        let tag_at = TABLE_LEN_OFFSET + 4 + table_len;

        wire[tag_at] = 0xff;
        assert_eq!(
            decode_request(&wire).unwrap_err(),
            CodecError::BadTag(0xff)
        );
    }

    #[test]
    fn non_zero_flags_rejected() {
        let mut wire = encode_request(&knn_request());
        // Byte 4 is the low byte of the little-endian flags field.
        wire[4] = 0x01;
        assert_eq!(decode_request(&wire).unwrap_err(), CodecError::BadFlags);
    }
}