1pub mod filter;
59
60pub use filter::Filter;
61use redis::aio::ConnectionManager;
62use rig_core::{
63 Embed, OneOrMany,
64 embeddings::embedding::{Embedding, EmbeddingModel},
65 vector_store::{
66 InsertDocuments, TopNResults, VectorStoreError, VectorStoreIndex, VectorStoreIndexDyn,
67 request::{Filter as CoreFilter, VectorSearchRequest},
68 },
69 wasm_compat::WasmBoxedFuture,
70};
71use serde::{Deserialize, Serialize};
72
73pub struct RedisVectorStore<M>
90where
91 M: EmbeddingModel,
92{
93 model: M,
94 connection_manager: ConnectionManager,
95 index_name: String,
96 vector_field: String,
97 key_prefix: Option<String>,
98 metadata_fields: Vec<String>,
99 distance_metric: DistanceMetric,
100}
101
102impl<M> RedisVectorStore<M>
103where
104 M: EmbeddingModel,
105{
106 pub async fn new(
120 model: M,
121 client: redis::Client,
122 index_name: String,
123 vector_field: String,
124 ) -> Result<Self, VectorStoreError> {
125 let connection_manager = ConnectionManager::new(client)
126 .await
127 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
128
129 Ok(Self {
130 model,
131 connection_manager,
132 index_name,
133 vector_field,
134 key_prefix: None,
135 metadata_fields: Vec::new(),
136 distance_metric: DistanceMetric::default(),
137 })
138 }
139
140 pub fn with_distance_metric(mut self, metric: DistanceMetric) -> Self {
146 self.distance_metric = metric;
147 self
148 }
149
150 pub fn with_key_prefix(mut self, prefix: String) -> Self {
156 self.key_prefix = Some(prefix);
157 self
158 }
159
160 pub fn with_metadata_fields(mut self, fields: Vec<String>) -> Self {
179 self.metadata_fields = filter_reserved_metadata_fields(fields, &self.vector_field);
180 self
181 }
182
183 pub async fn validate_index(&self) -> Result<(), VectorStoreError> {
193 let mut con = self.connection_manager.clone();
194 let info: redis::Value = redis::cmd("FT.INFO")
195 .arg(&self.index_name)
196 .query_async(&mut con)
197 .await
198 .map_err(|e| {
199 VectorStoreError::DatastoreError(
200 format!(
201 "index '{}' not found or FT.INFO failed: {e}",
202 self.index_name
203 )
204 .into(),
205 )
206 })?;
207
208 let mut tokens = Vec::new();
209 Self::flatten_tokens(&info, &mut tokens);
210
211 let expected = self.distance_metric.as_arg();
212 for (i, tok) in tokens.iter().enumerate() {
213 if tok.eq_ignore_ascii_case("distance_metric") {
214 match tokens.get(i + 1) {
215 Some(m) if m.eq_ignore_ascii_case(expected) => {}
216 other => {
217 return Err(VectorStoreError::DatastoreError(
218 format!(
219 "index '{}' uses distance metric {:?}, but this store is configured for {}",
220 self.index_name, other, expected
221 )
222 .into(),
223 ));
224 }
225 }
226 }
227 }
228
229 if let Some(prefix) = &self.key_prefix {
230 const STOP: &[&str] = &[
231 "default_score",
232 "filter",
233 "language",
234 "language_field",
235 "score_field",
236 "payload_field",
237 "attributes",
238 ];
239 let found = tokens
240 .iter()
241 .position(|t| t == "prefixes")
242 .map(|p| {
243 tokens[p + 1..]
244 .iter()
245 .take_while(|t| !STOP.contains(&t.as_str()))
246 .any(|t| t == prefix)
247 })
248 .unwrap_or(false);
249 if !found {
250 return Err(VectorStoreError::DatastoreError(
251 format!(
252 "index '{}' is not configured with key prefix '{}'",
253 self.index_name, prefix
254 )
255 .into(),
256 ));
257 }
258 }
259
260 Ok(())
261 }
262
263 pub async fn create_index(
270 &self,
271 dimensions: usize,
272 metadata_fields: &[(String, MetadataFieldType)],
273 ) -> Result<(), VectorStoreError> {
274 let mut con = self.connection_manager.clone();
275 let mut cmd = redis::cmd("FT.CREATE");
276 cmd.arg(&self.index_name).arg("ON").arg("HASH");
277 if let Some(prefix) = &self.key_prefix {
278 cmd.arg("PREFIX").arg(1).arg(prefix);
279 }
280 cmd.arg("SCHEMA")
281 .arg("document")
282 .arg("TEXT")
283 .arg("embedded_text")
284 .arg("TEXT")
285 .arg(&self.vector_field)
286 .arg("VECTOR")
287 .arg("FLAT")
288 .arg(6)
289 .arg("TYPE")
290 .arg("FLOAT32")
291 .arg("DIM")
292 .arg(dimensions)
293 .arg("DISTANCE_METRIC")
294 .arg(self.distance_metric.as_arg());
295 for (name, ty) in metadata_fields {
296 cmd.arg(name).arg(ty.as_arg());
297 }
298 cmd.query_async::<()>(&mut con)
299 .await
300 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
301 }
302
303 pub async fn delete(&self, ids: &[String]) -> Result<u64, VectorStoreError> {
307 if ids.is_empty() {
308 return Ok(0);
309 }
310 let mut con = self.connection_manager.clone();
311 let mut cmd = redis::cmd("UNLINK");
312 for id in ids {
313 cmd.arg(id);
314 }
315 cmd.query_async::<u64>(&mut con)
316 .await
317 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
318 }
319
320 async fn embed_query(&self, query: &str) -> Result<Vec<u8>, VectorStoreError> {
322 let embedding = self.model.embed_text(query).await?;
323 if embedding.vec.iter().any(|x| !x.is_finite()) {
324 return Err(VectorStoreError::DatastoreError(
325 "query embedding contains non-finite (NaN/Inf) values".into(),
326 ));
327 }
328 Ok(Self::embedding_to_bytes(&embedding.vec))
329 }
330
331 fn embedding_to_bytes(embedding: &[f64]) -> Vec<u8> {
333 embedding
334 .iter()
335 .flat_map(|&x| (x as f32).to_le_bytes())
336 .collect()
337 }
338
339 fn extract_string(value: &redis::Value) -> Option<String> {
341 match value {
342 redis::Value::BulkString(bytes) => Some(String::from_utf8_lossy(bytes).to_string()),
343 redis::Value::SimpleString(s) => Some(s.clone()),
344 redis::Value::VerbatimString { text, .. } => Some(text.clone()),
345 _ => None,
346 }
347 }
348
349 fn extract_distance(value: &redis::Value) -> Result<f64, VectorStoreError> {
354 let distance = match value {
355 redis::Value::Double(d) => *d,
356 redis::Value::BulkString(bytes) => {
357 String::from_utf8_lossy(bytes).parse::<f64>().map_err(|e| {
358 VectorStoreError::DatastoreError(format!("Failed to parse score: {e}").into())
359 })?
360 }
361 redis::Value::SimpleString(s) | redis::Value::VerbatimString { text: s, .. } => {
362 s.parse::<f64>().map_err(|e| {
363 VectorStoreError::DatastoreError(format!("Failed to parse score: {e}").into())
364 })?
365 }
366 other => {
367 return Err(VectorStoreError::DatastoreError(
368 format!("Unexpected Redis value type for score: {other:?}").into(),
369 ));
370 }
371 };
372 Ok(distance)
373 }
374
375 fn parse_search_response<T>(
380 response: redis::Value,
381 ) -> Result<Vec<(f64, String, T)>, VectorStoreError>
382 where
383 T: for<'a> Deserialize<'a>,
384 {
385 Self::parse_response_generic(response, true).map(|items| {
386 items
387 .into_iter()
388 .filter_map(|(score, id, doc_json)| {
389 if doc_json.is_empty() {
390 tracing::warn!(
391 target: "rig",
392 id = %id,
393 "Document field missing or empty in hash, skipping"
394 );
395 return None;
396 }
397 match serde_json::from_str::<T>(&doc_json) {
398 Ok(doc) => Some((score, id, doc)),
399 Err(e) => {
400 tracing::warn!(
401 target: "rig",
402 id = %id,
403 error = %e,
404 "Failed to deserialize document, skipping"
405 );
406 None
407 }
408 }
409 })
410 .collect()
411 })
412 }
413
414 fn parse_search_response_ids(
416 response: redis::Value,
417 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
418 Self::parse_response_generic(response, false).map(|items| {
419 items
420 .into_iter()
421 .map(|(score, id, _)| (score, id))
422 .collect()
423 })
424 }
425
426 fn parse_response_generic(
429 response: redis::Value,
430 include_document: bool,
431 ) -> Result<Vec<(f64, String, String)>, VectorStoreError> {
432 match response {
433 redis::Value::Map(pairs) => Self::parse_resp3_map(&pairs, include_document),
435 redis::Value::Array(items) => Self::parse_resp2_array(&items, include_document),
437 _ => Err(VectorStoreError::DatastoreError(
438 "Invalid FT.SEARCH response format (expected a RESP2 array or RESP3 map)".into(),
439 )),
440 }
441 }
442
443 fn parse_resp2_array(
445 items: &[redis::Value],
446 include_document: bool,
447 ) -> Result<Vec<(f64, String, String)>, VectorStoreError> {
448 let count = match items.first() {
449 Some(redis::Value::Int(n)) => *n as usize,
450 _ => {
451 return Err(VectorStoreError::DatastoreError(
452 "Invalid response format: expected count as first element".into(),
453 ));
454 }
455 };
456
457 if count == 0 {
458 return Ok(Vec::new());
459 }
460
461 let mut results = Vec::with_capacity(count);
462
463 let mut iter = items.iter().skip(1);
464 while let Some(key_val) = iter.next() {
465 let id = match Self::extract_string(key_val) {
466 Some(id) => id,
467 None => {
468 iter.next();
469 continue;
470 }
471 };
472
473 let fields_val = match iter.next() {
474 Some(redis::Value::Array(fields)) => fields,
475 _ => continue,
476 };
477
478 let mut distance = 0.0;
479 let mut score_found = false;
480 let mut document_json = String::new();
481
482 for chunk in fields_val.chunks(2) {
483 let [name_val, value_val] = chunk else {
484 continue;
485 };
486 let field_name = match Self::extract_string(name_val) {
487 Some(name) => name,
488 None => continue,
489 };
490
491 if field_name == "__vector_score" {
492 distance = Self::extract_distance(value_val)?;
493 score_found = true;
494 } else if include_document && field_name == "document" {
495 match Self::extract_string(value_val) {
496 Some(json) => document_json = json,
497 None => {
498 tracing::warn!(
499 target: "rig",
500 id = %id,
501 "Document field present but could not be extracted as string"
502 );
503 }
504 }
505 }
506 }
507
508 if !score_found {
509 tracing::warn!(
510 target: "rig",
511 id = %id,
512 "__vector_score field missing from search result, defaulting to 0.0"
513 );
514 }
515
516 results.push((distance, id, document_json));
517 }
518
519 Ok(results)
520 }
521
522 fn parse_resp3_map(
524 pairs: &[(redis::Value, redis::Value)],
525 include_document: bool,
526 ) -> Result<Vec<(f64, String, String)>, VectorStoreError> {
527 let entries = pairs
528 .iter()
529 .find_map(|(k, v)| match (Self::extract_string(k), v) {
530 (Some(name), redis::Value::Array(items)) if name == "results" => Some(items),
531 _ => None,
532 });
533
534 let Some(entries) = entries else {
535 return Ok(Vec::new());
537 };
538
539 let mut results = Vec::with_capacity(entries.len());
540 for entry in entries {
541 let redis::Value::Map(fields) = entry else {
542 continue;
543 };
544
545 let mut id = String::new();
546 let mut distance = 0.0;
547 let mut score_found = false;
548 let mut document_json = String::new();
549
550 for (k, v) in fields {
551 match Self::extract_string(k).as_deref() {
552 Some("id") => {
553 if let Some(s) = Self::extract_string(v) {
554 id = s;
555 }
556 }
557 Some("extra_attributes") => {
558 if let redis::Value::Map(attrs) = v {
559 for (ak, av) in attrs {
560 match Self::extract_string(ak).as_deref() {
561 Some("__vector_score") => {
562 distance = Self::extract_distance(av)?;
563 score_found = true;
564 }
565 Some("document") if include_document => {
566 if let Some(s) = Self::extract_string(av) {
567 document_json = s;
568 }
569 }
570 _ => {}
571 }
572 }
573 }
574 }
575 _ => {}
576 }
577 }
578
579 if !score_found {
580 tracing::warn!(
581 target: "rig",
582 id = %id,
583 "__vector_score field missing from search result, defaulting to 0.0"
584 );
585 }
586
587 results.push((distance, id, document_json));
588 }
589
590 Ok(results)
591 }
592
593 fn flatten_tokens(value: &redis::Value, out: &mut Vec<String>) {
596 match value {
597 redis::Value::Array(items) | redis::Value::Set(items) => {
598 for v in items {
599 Self::flatten_tokens(v, out);
600 }
601 }
602 redis::Value::Map(pairs) => {
603 for (k, v) in pairs {
604 Self::flatten_tokens(k, out);
605 Self::flatten_tokens(v, out);
606 }
607 }
608 redis::Value::BulkString(bytes) => out.push(String::from_utf8_lossy(bytes).to_string()),
609 redis::Value::SimpleString(s) => out.push(s.clone()),
610 redis::Value::VerbatimString { text, .. } => out.push(text.clone()),
611 redis::Value::Int(i) => out.push(i.to_string()),
612 redis::Value::Double(d) => out.push(d.to_string()),
613 _ => {}
614 }
615 }
616
617 async fn execute_search(
619 &self,
620 vector_bytes: Vec<u8>,
621 req: &VectorSearchRequest<Filter>,
622 include_document: bool,
623 ) -> Result<redis::Value, VectorStoreError> {
624 let mut con = self.connection_manager.clone();
625
626 let filter_str = req
627 .filter()
628 .as_ref()
629 .map(|f| f.clone().into_inner())
630 .unwrap_or_else(|| "*".to_string());
631
632 let knn_query = format!(
633 "{}=>[KNN {} @{} $vec AS __vector_score]",
634 filter_str,
635 req.samples(),
636 self.vector_field
637 );
638
639 let mut cmd = redis::cmd("FT.SEARCH");
640 cmd.arg(&self.index_name)
641 .arg(&knn_query)
642 .arg("PARAMS")
643 .arg(2)
644 .arg("vec")
645 .arg(vector_bytes)
646 .arg("SORTBY")
647 .arg("__vector_score")
648 .arg("RETURN");
649
650 if include_document {
651 cmd.arg(2).arg("__vector_score").arg("document");
652 } else {
653 cmd.arg(1).arg("__vector_score");
654 }
655
656 cmd.arg("DIALECT").arg(2);
657
658 cmd.arg("LIMIT").arg(0).arg(req.samples());
660
661 cmd.query_async(&mut con)
662 .await
663 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
664 }
665
666 fn json_value_to_hash_field(value: &serde_json::Value) -> Option<String> {
671 match value {
672 serde_json::Value::String(s) => Some(s.clone()),
673 serde_json::Value::Number(n) => Some(n.to_string()),
674 serde_json::Value::Bool(b) => Some(if *b { "1".to_string() } else { "0".to_string() }),
675 serde_json::Value::Null
676 | serde_json::Value::Array(_)
677 | serde_json::Value::Object(_) => None,
678 }
679 }
680}
681
682impl<Model> InsertDocuments for RedisVectorStore<Model>
683where
684 Model: EmbeddingModel + Send + Sync,
685{
686 async fn insert_documents<Doc: Serialize + Embed + Send>(
692 &self,
693 documents: Vec<(Doc, OneOrMany<Embedding>)>,
694 ) -> Result<(), VectorStoreError> {
695 let mut con = self.connection_manager.clone();
696 let mut pipe = redis::pipe();
697
698 for (document, embeddings) in &documents {
699 let json_value = serde_json::to_value(document)?;
700 let json_document = json_value.to_string();
701
702 let metadata: Vec<(String, String)> = if self.metadata_fields.is_empty() {
704 Vec::new()
705 } else {
706 self.metadata_fields
707 .iter()
708 .filter_map(|field_name| {
709 let value = json_value.get(field_name)?;
710 match Self::json_value_to_hash_field(value) {
711 Some(hash_value) => Some((field_name.clone(), hash_value)),
712 None => {
713 tracing::warn!(
714 target: "rig",
715 field = %field_name,
716 value_type = %value,
717 "Metadata field has unsupported type (null/array/object), skipping"
718 );
719 None
720 }
721 }
722 })
723 .collect()
724 };
725
726 for embedding in embeddings.iter() {
727 let id = if let Some(ref prefix) = self.key_prefix {
728 format!("{}{}", prefix, uuid::Uuid::new_v4())
729 } else {
730 uuid::Uuid::new_v4().to_string()
731 };
732 let embedding_bytes = Self::embedding_to_bytes(&embedding.vec);
733
734 let cmd = pipe
735 .cmd("HSET")
736 .arg(&id)
737 .arg("document")
738 .arg(json_document.as_bytes())
739 .arg("embedded_text")
740 .arg(embedding.document.as_bytes())
741 .arg(&self.vector_field)
742 .arg(embedding_bytes);
743
744 for (field_name, field_value) in &metadata {
745 cmd.arg(field_name).arg(field_value.as_bytes());
746 }
747
748 cmd.ignore();
749 }
750 }
751
752 pipe.query_async::<()>(&mut con)
753 .await
754 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
755
756 tracing::debug!(
757 target: "rig",
758 index = %self.index_name,
759 count = documents.len(),
760 metadata_fields = ?self.metadata_fields,
761 "Inserted documents into Redis vector store"
762 );
763
764 Ok(())
765 }
766}
767
768impl<M> VectorStoreIndex for RedisVectorStore<M>
769where
770 M: EmbeddingModel + Send + Sync,
771{
772 type Filter = Filter;
773
774 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
775 &self,
776 req: VectorSearchRequest<Self::Filter>,
777 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
778 if req.samples() == 0 {
779 return Ok(Vec::new());
780 }
781 let vector_bytes = self.embed_query(req.query()).await?;
782
783 let response = self.execute_search(vector_bytes, &req, true).await?;
784 let mut results = Self::parse_search_response::<T>(response)?
785 .into_iter()
786 .map(|(distance, id, doc)| (self.distance_metric.score(distance), id, doc))
787 .collect::<Vec<_>>();
788
789 if let Some(threshold) = req.threshold() {
790 results.retain(|(score, _, _)| *score >= threshold);
791 }
792
793 tracing::debug!(
794 target: "rig",
795 index = %self.index_name,
796 query = %req.query(),
797 "Selected documents: {}",
798 results.iter().map(|(score, id, _)| format!("{id} ({score:.4})")).collect::<Vec<_>>().join(", ")
799 );
800
801 Ok(results)
802 }
803
804 async fn top_n_ids(
805 &self,
806 req: VectorSearchRequest<Self::Filter>,
807 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
808 if req.samples() == 0 {
809 return Ok(Vec::new());
810 }
811 let vector_bytes = self.embed_query(req.query()).await?;
812
813 let response = self.execute_search(vector_bytes, &req, false).await?;
814 let mut results = Self::parse_search_response_ids(response)?
815 .into_iter()
816 .map(|(distance, id)| (self.distance_metric.score(distance), id))
817 .collect::<Vec<_>>();
818
819 if let Some(threshold) = req.threshold() {
820 results.retain(|(score, _)| *score >= threshold);
821 }
822
823 tracing::debug!(
824 target: "rig",
825 index = %self.index_name,
826 query = %req.query(),
827 "Selected document IDs: {}",
828 results.iter().map(|(score, id)| format!("{id} ({score:.4})")).collect::<Vec<_>>().join(", ")
829 );
830
831 Ok(results)
832 }
833}
834
835impl<M> VectorStoreIndexDyn for RedisVectorStore<M>
836where
837 M: EmbeddingModel + Sync + Send,
838{
839 fn top_n<'a>(
840 &'a self,
841 req: VectorSearchRequest<CoreFilter<serde_json::Value>>,
842 ) -> WasmBoxedFuture<'a, TopNResults> {
843 Box::pin(async move {
844 let req = req.try_map_filter(Filter::try_from)?;
845 let results = <Self as VectorStoreIndex>::top_n::<serde_json::Value>(self, req).await?;
846 Ok(results)
847 })
848 }
849
850 fn top_n_ids<'a>(
851 &'a self,
852 req: VectorSearchRequest<CoreFilter<serde_json::Value>>,
853 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
854 Box::pin(async move {
855 let req = req.try_map_filter(Filter::try_from)?;
856 let results = <Self as VectorStoreIndex>::top_n_ids(self, req).await?;
857 Ok(results)
858 })
859 }
860}
861
862fn filter_reserved_metadata_fields(fields: Vec<String>, vector_field: &str) -> Vec<String> {
866 let reserved = ["document", "embedded_text", vector_field];
867 fields
868 .into_iter()
869 .filter(|f| {
870 if reserved.contains(&f.as_str()) {
871 tracing::warn!(
872 target: "rig",
873 field = %f,
874 "Metadata field name conflicts with reserved hash field, skipping"
875 );
876 false
877 } else {
878 true
879 }
880 })
881 .collect()
882}
883
/// Distance metric of the vector index (the `DISTANCE_METRIC` argument of
/// `FT.CREATE`). Defaults to [`DistanceMetric::Cosine`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DistanceMetric {
    /// Cosine distance (`COSINE`).
    #[default]
    Cosine,
    /// `L2` distance.
    L2,
    /// Inner-product distance (`IP`).
    InnerProduct,
}

impl DistanceMetric {
    /// RediSearch token for this metric.
    fn as_arg(self) -> &'static str {
        match self {
            Self::Cosine => "COSINE",
            Self::L2 => "L2",
            Self::InnerProduct => "IP",
        }
    }

    /// Maps a raw `__vector_score` distance to a similarity score where larger
    /// means closer.
    ///
    /// `L2` uses `1 / (1 + distance)`; cosine and inner product use
    /// `1 - distance`. Both mappings strictly decrease as distance grows.
    fn score(self, distance: f64) -> f64 {
        if self == Self::L2 {
            1.0 / (1.0 + distance)
        } else {
            1.0 - distance
        }
    }
}
919
/// Schema type of an extra metadata attribute passed to `create_index`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetadataFieldType {
    /// `TAG` attribute.
    Tag,
    /// `NUMERIC` attribute.
    Numeric,
    /// `TEXT` attribute.
    Text,
}

impl MetadataFieldType {
    /// Keyword emitted after the field name in `FT.CREATE ... SCHEMA`.
    fn as_arg(self) -> &'static str {
        match self {
            Self::Tag => "TAG",
            Self::Numeric => "NUMERIC",
            Self::Text => "TEXT",
        }
    }
}
940
#[cfg(test)]
mod tests {
    use super::*;
    use rig_core::embeddings::embedding::EmbeddingError;

    /// Stand-in model; these tests only exercise pure helpers and never embed.
    struct FakeModel;

    impl EmbeddingModel for FakeModel {
        const MAX_DOCUMENTS: usize = 1024;
        type Client = ();

        fn make(_client: &Self::Client, _model: impl Into<String>, _dims: Option<usize>) -> Self {
            FakeModel
        }

        fn ndims(&self) -> usize {
            3
        }

        async fn embed_texts(
            &self,
            _texts: impl IntoIterator<Item = String> + Send,
        ) -> Result<Vec<Embedding>, EmbeddingError> {
            Ok(Vec::new())
        }
    }

    type Store = RedisVectorStore<FakeModel>;

    /// Builds a RESP bulk string from UTF-8 text.
    fn bulk(s: &str) -> redis::Value {
        redis::Value::BulkString(s.as_bytes().to_vec())
    }

    /// Float comparison with the tolerance used throughout these tests.
    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-9
    }

    #[test]
    fn reserved_metadata_fields_are_filtered() {
        let requested: Vec<String> = [
            "category",
            "document",
            "embedded_text",
            "embedding",
            "price",
        ]
        .iter()
        .map(|name| name.to_string())
        .collect();

        let kept = filter_reserved_metadata_fields(requested, "embedding");
        assert_eq!(kept, ["category", "price"]);
    }

    #[test]
    fn json_value_to_hash_field_covers_all_types() {
        let cases = [
            (serde_json::json!("hello"), Some("hello")),
            (serde_json::json!(3), Some("3")),
            (serde_json::json!(true), Some("1")),
            (serde_json::json!(false), Some("0")),
            (serde_json::Value::Null, None),
            (serde_json::json!([1, 2]), None),
            (serde_json::json!({"a": 1}), None),
        ];
        for (input, expected) in cases {
            assert_eq!(
                Store::json_value_to_hash_field(&input).as_deref(),
                expected,
                "input: {input}"
            );
        }
    }

    #[test]
    fn embedding_to_bytes_is_float32_le() {
        // 1.0f32 is 0x3F800000, stored little-endian.
        let encoded = Store::embedding_to_bytes(&[1.0_f64]);
        assert_eq!(encoded, vec![0x00, 0x00, 0x80, 0x3f]);
    }

    #[test]
    fn parse_search_response_skips_empty_documents() {
        let kept_entry = redis::Value::Array(vec![
            bulk("__vector_score"),
            bulk("0.1"),
            bulk("document"),
            bulk("{\"a\":1}"),
        ]);
        let empty_entry = redis::Value::Array(vec![
            bulk("__vector_score"),
            bulk("0.2"),
            bulk("document"),
            bulk(""),
        ]);
        let reply = redis::Value::Array(vec![
            redis::Value::Int(2),
            bulk("doc:1"),
            kept_entry,
            bulk("doc:2"),
            empty_entry,
        ]);

        let parsed = Store::parse_search_response::<serde_json::Value>(reply).expect("parse ok");
        assert_eq!(parsed.len(), 1);
        let (distance, id, _) = &parsed[0];
        assert_eq!(id, "doc:1");
        assert!(approx(*distance, 0.1));
    }

    #[test]
    fn parse_search_response_empty_when_count_zero() {
        let reply = redis::Value::Array(vec![redis::Value::Int(0)]);
        let parsed = Store::parse_search_response::<serde_json::Value>(reply).expect("parse ok");
        assert!(parsed.is_empty());
    }

    #[test]
    fn parse_resp3_map_response() {
        let hit = redis::Value::Map(vec![
            (bulk("id"), bulk("d:1")),
            (
                bulk("extra_attributes"),
                redis::Value::Map(vec![
                    (bulk("__vector_score"), bulk("0.1")),
                    (bulk("document"), bulk("{\"a\":1}")),
                ]),
            ),
        ]);
        let reply = redis::Value::Map(vec![
            (bulk("attributes"), redis::Value::Array(vec![])),
            (bulk("format"), bulk("STRING")),
            (bulk("results"), redis::Value::Array(vec![hit])),
            (bulk("total_results"), redis::Value::Int(1)),
        ]);

        let parsed = Store::parse_search_response::<serde_json::Value>(reply).expect("parse ok");
        assert_eq!(parsed.len(), 1);
        let (distance, id, _) = &parsed[0];
        assert_eq!(id, "d:1");
        assert!(approx(*distance, 0.1));
    }

    #[test]
    fn parse_resp3_map_empty_results() {
        let reply = redis::Value::Map(vec![
            (bulk("results"), redis::Value::Array(vec![])),
            (bulk("total_results"), redis::Value::Int(0)),
        ]);
        let parsed = Store::parse_search_response::<serde_json::Value>(reply).expect("parse ok");
        assert!(parsed.is_empty());
    }

    #[test]
    fn distance_metric_score_conversions() {
        let cases = [
            (DistanceMetric::Cosine, 0.0, 1.0),
            (DistanceMetric::Cosine, 2.0, -1.0),
            (DistanceMetric::InnerProduct, 0.0, 1.0),
            (DistanceMetric::InnerProduct, 0.5, 0.5),
            (DistanceMetric::L2, 0.0, 1.0),
            (DistanceMetric::L2, 3.0, 0.25),
        ];
        for (metric, distance, expected) in cases {
            assert!(
                approx(metric.score(distance), expected),
                "{metric:?} at distance {distance}"
            );
        }
    }

    #[test]
    fn distance_metric_score_is_monotonic_decreasing() {
        let metrics = [
            DistanceMetric::Cosine,
            DistanceMetric::L2,
            DistanceMetric::InnerProduct,
        ];
        for metric in metrics {
            assert!(
                metric.score(0.1) > metric.score(0.5),
                "{metric:?} score must decrease as distance grows"
            );
        }
    }

    #[test]
    fn distance_metric_as_arg() {
        let expected = [
            (DistanceMetric::Cosine, "COSINE"),
            (DistanceMetric::L2, "L2"),
            (DistanceMetric::InnerProduct, "IP"),
        ];
        for (metric, token) in expected {
            assert_eq!(metric.as_arg(), token);
        }
    }
}