Skip to main content

rig_redis_vectorstore/
lib.rs

1//! Redis vector store integration for Rig.
2//!
3//! Provides a [`RedisVectorStore`] that implements Rig's [`VectorStoreIndex`] and
4//! [`InsertDocuments`] traits using RediSearch's vector similarity search (`FT.SEARCH`).
5//!
6//! # Prerequisites
7//!
8//! The RediSearch index must be created before using this store. The expected schema is:
9//! - A HASH-based index with the specified prefix
10//! - A `document` field of type TEXT (stores serialized JSON)
11//! - An `embedded_text` field of type TEXT (stores the source text)
12//! - A vector field (configurable name) of type VECTOR with FLOAT32 elements
13//! - Optionally, additional fields for metadata filtering (TAG, NUMERIC, etc.)
14//!
15//! # Distance Metric
16//!
17//! The metric is configurable via [`RedisVectorStore::with_distance_metric`] and must
18//! match the index's `DISTANCE_METRIC`. [`DistanceMetric::Cosine`] (the default),
19//! [`DistanceMetric::L2`], and [`DistanceMetric::InnerProduct`] are supported. Returned
20//! distances are converted to similarity scores (higher = more similar) per metric; see
21//! [`DistanceMetric`]. Use [`RedisVectorStore::validate_index`] to confirm the index
22//! agrees with the configured metric.
23//!
24//! # Metadata Filtering
25//!
26//! To enable filtering on document fields during search, configure metadata fields
27//! via [`RedisVectorStore::with_metadata_fields`]. These fields are extracted from
28//! the serialized document JSON during insertion and written as separate hash fields,
29//! making them available for RediSearch filter queries. Your index schema must declare
30//! these fields with appropriate types (TAG, NUMERIC, TEXT) for filters to work.
31//!
32//! # Limitations
33//!
34//! - **Single-node only.** Inserts are pipelined across multiple keys, which is not
35//!   compatible with Redis Cluster (CROSSSLOT). Cluster support is a planned follow-up.
36//! - **Key prefix must match the index `PREFIX`**, otherwise inserted documents are
37//!   stored but never indexed.
38//! - **Multiple embeddings per document** produce multiple independently searchable
39//!   hashes, so a single logical document may appear more than once in results.
40//!
41//! Both RESP2 and RESP3 `FT.SEARCH` reply shapes are parsed.
42//!
43//! # Example
44//! ```ignore
45//! use rig_redis_vectorstore::RedisVectorStore;
46//!
47//! let store = RedisVectorStore::new(
48//!     embedding_model,
49//!     redis_client,
50//!     "my_index".into(),
51//!     "embedding".into(),
52//! )
53//! .await?
54//! .with_key_prefix("doc:".to_string())
55//! .with_metadata_fields(vec!["category".to_string(), "price".to_string()]);
56//! ```
57
58pub mod filter;
59
60pub use filter::Filter;
61use redis::aio::ConnectionManager;
62use rig_core::{
63    Embed, OneOrMany,
64    embeddings::embedding::{Embedding, EmbeddingModel},
65    vector_store::{
66        InsertDocuments, TopNResults, VectorStoreError, VectorStoreIndex, VectorStoreIndexDyn,
67        request::{Filter as CoreFilter, VectorSearchRequest},
68    },
69    wasm_compat::WasmBoxedFuture,
70};
71use serde::{Deserialize, Serialize};
72
73/// Redis vector store implementation using RediSearch vector similarity search.
74///
75/// Uses Redis's `FT.SEARCH` command with KNN vector queries for similarity search.
76/// Internally holds a [`ConnectionManager`] for automatic reconnection on transient failures.
77///
78/// # Key Prefix
79///
80/// If your RediSearch index uses a `PREFIX` configuration (e.g., `PREFIX 1 doc:`),
81/// you **must** call [`RedisVectorStore::with_key_prefix`] with the matching prefix
82/// so that inserted documents are discoverable by the index.
83///
84/// # Metadata Fields
85///
86/// Configure metadata fields via [`RedisVectorStore::with_metadata_fields`] to enable
87/// filtering. During insertion, these fields are extracted from the serialized document
88/// and stored as separate hash fields that RediSearch can index and filter on.
89pub struct RedisVectorStore<M>
90where
91    M: EmbeddingModel,
92{
93    model: M,
94    connection_manager: ConnectionManager,
95    index_name: String,
96    vector_field: String,
97    key_prefix: Option<String>,
98    metadata_fields: Vec<String>,
99    distance_metric: DistanceMetric,
100}
101
102impl<M> RedisVectorStore<M>
103where
104    M: EmbeddingModel,
105{
106    /// Creates a new Redis vector store instance.
107    ///
108    /// Establishes a [`ConnectionManager`] from the provided client for automatic
109    /// reconnection on transient network failures.
110    ///
111    /// # Arguments
112    /// * `model` - Embedding model for query vectorization
113    /// * `client` - Redis client instance
114    /// * `index_name` - Name of the RediSearch index to query
115    /// * `vector_field` - Name of the vector field in the index
116    ///
117    /// # Errors
118    /// Returns an error if the initial connection to Redis cannot be established.
119    pub async fn new(
120        model: M,
121        client: redis::Client,
122        index_name: String,
123        vector_field: String,
124    ) -> Result<Self, VectorStoreError> {
125        let connection_manager = ConnectionManager::new(client)
126            .await
127            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
128
129        Ok(Self {
130            model,
131            connection_manager,
132            index_name,
133            vector_field,
134            key_prefix: None,
135            metadata_fields: Vec::new(),
136            distance_metric: DistanceMetric::default(),
137        })
138    }
139
140    /// Sets the distance metric the index uses (default [`DistanceMetric::Cosine`]).
141    ///
142    /// This must match the `DISTANCE_METRIC` of the RediSearch index so that
143    /// returned distances are converted to similarity scores correctly. Use
144    /// [`Self::validate_index`] to verify the index agrees.
145    pub fn with_distance_metric(mut self, metric: DistanceMetric) -> Self {
146        self.distance_metric = metric;
147        self
148    }
149
150    /// Sets a key prefix for document keys.
151    ///
152    /// Documents stored via [`InsertDocuments`] will be keyed as `{prefix}{uuid}`.
153    /// This prefix **must** match the index's `PREFIX` configuration for documents
154    /// to be indexed and discoverable by `FT.SEARCH`.
155    pub fn with_key_prefix(mut self, prefix: String) -> Self {
156        self.key_prefix = Some(prefix);
157        self
158    }
159
160    /// Configures metadata fields to extract from documents during insertion.
161    ///
162    /// When documents are inserted, the specified fields are extracted from the
163    /// serialized JSON representation and written as separate hash fields, making
164    /// them available for RediSearch filter queries (TAG, NUMERIC, TEXT). The field
165    /// names must match top-level keys in the serialized document JSON **and** be
166    /// declared in the RediSearch index schema. Calling this method replaces any
167    /// previously configured field list.
168    ///
169    /// Fields that are missing from a document or have null/complex values are
170    /// skipped with a warning log. Reserved field names (`document`, `embedded_text`,
171    /// and the configured vector field) are filtered out with a warning to prevent
172    /// data corruption.
173    ///
174    /// Note: RediSearch TAG fields split stored values on a separator (`,` by
175    /// default). Extracted string values containing the separator will be indexed
176    /// as multiple tags; create the TAG field with a different `SEPARATOR` if your
177    /// values may contain commas.
178    pub fn with_metadata_fields(mut self, fields: Vec<String>) -> Self {
179        self.metadata_fields = filter_reserved_metadata_fields(fields, &self.vector_field);
180        self
181    }
182
183    /// Validates that the configured index exists and is compatible with this store.
184    ///
185    /// Checks, via `FT.INFO`, that:
186    /// - the index exists,
187    /// - every vector field uses the store's configured distance metric, and
188    /// - if a key prefix is configured, the index is defined with that prefix
189    ///   (otherwise inserted documents would never be indexed).
190    ///
191    /// Call this after building the store to fail fast on schema mismatches.
192    pub async fn validate_index(&self) -> Result<(), VectorStoreError> {
193        let mut con = self.connection_manager.clone();
194        let info: redis::Value = redis::cmd("FT.INFO")
195            .arg(&self.index_name)
196            .query_async(&mut con)
197            .await
198            .map_err(|e| {
199                VectorStoreError::DatastoreError(
200                    format!(
201                        "index '{}' not found or FT.INFO failed: {e}",
202                        self.index_name
203                    )
204                    .into(),
205                )
206            })?;
207
208        let mut tokens = Vec::new();
209        Self::flatten_tokens(&info, &mut tokens);
210
211        let expected = self.distance_metric.as_arg();
212        for (i, tok) in tokens.iter().enumerate() {
213            if tok.eq_ignore_ascii_case("distance_metric") {
214                match tokens.get(i + 1) {
215                    Some(m) if m.eq_ignore_ascii_case(expected) => {}
216                    other => {
217                        return Err(VectorStoreError::DatastoreError(
218                            format!(
219                                "index '{}' uses distance metric {:?}, but this store is configured for {}",
220                                self.index_name, other, expected
221                            )
222                            .into(),
223                        ));
224                    }
225                }
226            }
227        }
228
229        if let Some(prefix) = &self.key_prefix {
230            const STOP: &[&str] = &[
231                "default_score",
232                "filter",
233                "language",
234                "language_field",
235                "score_field",
236                "payload_field",
237                "attributes",
238            ];
239            let found = tokens
240                .iter()
241                .position(|t| t == "prefixes")
242                .map(|p| {
243                    tokens[p + 1..]
244                        .iter()
245                        .take_while(|t| !STOP.contains(&t.as_str()))
246                        .any(|t| t == prefix)
247                })
248                .unwrap_or(false);
249            if !found {
250                return Err(VectorStoreError::DatastoreError(
251                    format!(
252                        "index '{}' is not configured with key prefix '{}'",
253                        self.index_name, prefix
254                    )
255                    .into(),
256                ));
257            }
258        }
259
260        Ok(())
261    }
262
263    /// Creates the RediSearch index for this store (HASH, `FLAT`, FLOAT32, COSINE).
264    ///
265    /// Uses the store's index name, vector field, and (if set) key prefix, plus the
266    /// `document` and `embedded_text` TEXT fields. Add any metadata fields you intend
267    /// to filter on. This is a convenience for setups that manage the index in code;
268    /// production deployments may prefer to create the index out of band.
269    pub async fn create_index(
270        &self,
271        dimensions: usize,
272        metadata_fields: &[(String, MetadataFieldType)],
273    ) -> Result<(), VectorStoreError> {
274        let mut con = self.connection_manager.clone();
275        let mut cmd = redis::cmd("FT.CREATE");
276        cmd.arg(&self.index_name).arg("ON").arg("HASH");
277        if let Some(prefix) = &self.key_prefix {
278            cmd.arg("PREFIX").arg(1).arg(prefix);
279        }
280        cmd.arg("SCHEMA")
281            .arg("document")
282            .arg("TEXT")
283            .arg("embedded_text")
284            .arg("TEXT")
285            .arg(&self.vector_field)
286            .arg("VECTOR")
287            .arg("FLAT")
288            .arg(6)
289            .arg("TYPE")
290            .arg("FLOAT32")
291            .arg("DIM")
292            .arg(dimensions)
293            .arg("DISTANCE_METRIC")
294            .arg(self.distance_metric.as_arg());
295        for (name, ty) in metadata_fields {
296            cmd.arg(name).arg(ty.as_arg());
297        }
298        cmd.query_async::<()>(&mut con)
299            .await
300            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
301    }
302
303    /// Deletes documents by their hash keys (the IDs returned by [`Self::top_n_ids`]).
304    ///
305    /// Uses `UNLINK` (non-blocking delete). Returns the number of keys removed.
306    pub async fn delete(&self, ids: &[String]) -> Result<u64, VectorStoreError> {
307        if ids.is_empty() {
308            return Ok(0);
309        }
310        let mut con = self.connection_manager.clone();
311        let mut cmd = redis::cmd("UNLINK");
312        for id in ids {
313            cmd.arg(id);
314        }
315        cmd.query_async::<u64>(&mut con)
316            .await
317            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
318    }
319
320    /// Embeds a query string and returns FLOAT32 LE bytes, rejecting non-finite vectors.
321    async fn embed_query(&self, query: &str) -> Result<Vec<u8>, VectorStoreError> {
322        let embedding = self.model.embed_text(query).await?;
323        if embedding.vec.iter().any(|x| !x.is_finite()) {
324            return Err(VectorStoreError::DatastoreError(
325                "query embedding contains non-finite (NaN/Inf) values".into(),
326            ));
327        }
328        Ok(Self::embedding_to_bytes(&embedding.vec))
329    }
330
331    /// Converts f64 embedding vector to f32 little-endian bytes for Redis VECTOR fields.
332    fn embedding_to_bytes(embedding: &[f64]) -> Vec<u8> {
333        embedding
334            .iter()
335            .flat_map(|&x| (x as f32).to_le_bytes())
336            .collect()
337    }
338
339    /// Extracts a UTF-8 string from a Redis bulk/simple/verbatim string value.
340    fn extract_string(value: &redis::Value) -> Option<String> {
341        match value {
342            redis::Value::BulkString(bytes) => Some(String::from_utf8_lossy(bytes).to_string()),
343            redis::Value::SimpleString(s) => Some(s.clone()),
344            redis::Value::VerbatimString { text, .. } => Some(text.clone()),
345            _ => None,
346        }
347    }
348
349    /// Parses the raw distance value from a Redis score field.
350    ///
351    /// The distance is converted to a similarity score by [`DistanceMetric::score`]
352    /// according to the store's configured metric.
353    fn extract_distance(value: &redis::Value) -> Result<f64, VectorStoreError> {
354        let distance = match value {
355            redis::Value::Double(d) => *d,
356            redis::Value::BulkString(bytes) => {
357                String::from_utf8_lossy(bytes).parse::<f64>().map_err(|e| {
358                    VectorStoreError::DatastoreError(format!("Failed to parse score: {e}").into())
359                })?
360            }
361            redis::Value::SimpleString(s) | redis::Value::VerbatimString { text: s, .. } => {
362                s.parse::<f64>().map_err(|e| {
363                    VectorStoreError::DatastoreError(format!("Failed to parse score: {e}").into())
364                })?
365            }
366            other => {
367                return Err(VectorStoreError::DatastoreError(
368                    format!("Unexpected Redis value type for score: {other:?}").into(),
369                ));
370            }
371        };
372        Ok(distance)
373    }
374
375    /// Parses an FT.SEARCH response into results with deserialized documents.
376    ///
377    /// Documents with empty or unparseable JSON are skipped with a warning rather
378    /// than aborting the entire result set.
379    fn parse_search_response<T>(
380        response: redis::Value,
381    ) -> Result<Vec<(f64, String, T)>, VectorStoreError>
382    where
383        T: for<'a> Deserialize<'a>,
384    {
385        Self::parse_response_generic(response, true).map(|items| {
386            items
387                .into_iter()
388                .filter_map(|(score, id, doc_json)| {
389                    if doc_json.is_empty() {
390                        tracing::warn!(
391                            target: "rig",
392                            id = %id,
393                            "Document field missing or empty in hash, skipping"
394                        );
395                        return None;
396                    }
397                    match serde_json::from_str::<T>(&doc_json) {
398                        Ok(doc) => Some((score, id, doc)),
399                        Err(e) => {
400                            tracing::warn!(
401                                target: "rig",
402                                id = %id,
403                                error = %e,
404                                "Failed to deserialize document, skipping"
405                            );
406                            None
407                        }
408                    }
409                })
410                .collect()
411        })
412    }
413
414    /// Parses an FT.SEARCH response for IDs and scores only.
415    fn parse_search_response_ids(
416        response: redis::Value,
417    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
418        Self::parse_response_generic(response, false).map(|items| {
419            items
420                .into_iter()
421                .map(|(score, id, _)| (score, id))
422                .collect()
423        })
424    }
425
426    /// Generic response parser handling both RESP2 (array) and RESP3 (map)
427    /// `FT.SEARCH` reply shapes, in full-document or ID-only modes.
428    fn parse_response_generic(
429        response: redis::Value,
430        include_document: bool,
431    ) -> Result<Vec<(f64, String, String)>, VectorStoreError> {
432        match response {
433            // RESP3: a map with "results" => [ {id, extra_attributes: {..}}, .. ].
434            redis::Value::Map(pairs) => Self::parse_resp3_map(&pairs, include_document),
435            // RESP2: [count, key1, [field, val, ..], key2, [..], ..].
436            redis::Value::Array(items) => Self::parse_resp2_array(&items, include_document),
437            _ => Err(VectorStoreError::DatastoreError(
438                "Invalid FT.SEARCH response format (expected a RESP2 array or RESP3 map)".into(),
439            )),
440        }
441    }
442
443    /// Parses the RESP2 flat-array `FT.SEARCH` reply.
444    fn parse_resp2_array(
445        items: &[redis::Value],
446        include_document: bool,
447    ) -> Result<Vec<(f64, String, String)>, VectorStoreError> {
448        let count = match items.first() {
449            Some(redis::Value::Int(n)) => *n as usize,
450            _ => {
451                return Err(VectorStoreError::DatastoreError(
452                    "Invalid response format: expected count as first element".into(),
453                ));
454            }
455        };
456
457        if count == 0 {
458            return Ok(Vec::new());
459        }
460
461        let mut results = Vec::with_capacity(count);
462
463        let mut iter = items.iter().skip(1);
464        while let Some(key_val) = iter.next() {
465            let id = match Self::extract_string(key_val) {
466                Some(id) => id,
467                None => {
468                    iter.next();
469                    continue;
470                }
471            };
472
473            let fields_val = match iter.next() {
474                Some(redis::Value::Array(fields)) => fields,
475                _ => continue,
476            };
477
478            let mut distance = 0.0;
479            let mut score_found = false;
480            let mut document_json = String::new();
481
482            for chunk in fields_val.chunks(2) {
483                let [name_val, value_val] = chunk else {
484                    continue;
485                };
486                let field_name = match Self::extract_string(name_val) {
487                    Some(name) => name,
488                    None => continue,
489                };
490
491                if field_name == "__vector_score" {
492                    distance = Self::extract_distance(value_val)?;
493                    score_found = true;
494                } else if include_document && field_name == "document" {
495                    match Self::extract_string(value_val) {
496                        Some(json) => document_json = json,
497                        None => {
498                            tracing::warn!(
499                                target: "rig",
500                                id = %id,
501                                "Document field present but could not be extracted as string"
502                            );
503                        }
504                    }
505                }
506            }
507
508            if !score_found {
509                tracing::warn!(
510                    target: "rig",
511                    id = %id,
512                    "__vector_score field missing from search result, defaulting to 0.0"
513                );
514            }
515
516            results.push((distance, id, document_json));
517        }
518
519        Ok(results)
520    }
521
522    /// Parses the RESP3 map-shaped `FT.SEARCH` reply.
523    fn parse_resp3_map(
524        pairs: &[(redis::Value, redis::Value)],
525        include_document: bool,
526    ) -> Result<Vec<(f64, String, String)>, VectorStoreError> {
527        let entries = pairs
528            .iter()
529            .find_map(|(k, v)| match (Self::extract_string(k), v) {
530                (Some(name), redis::Value::Array(items)) if name == "results" => Some(items),
531                _ => None,
532            });
533
534        let Some(entries) = entries else {
535            // No "results" key (e.g. total_results 0) -> no matches.
536            return Ok(Vec::new());
537        };
538
539        let mut results = Vec::with_capacity(entries.len());
540        for entry in entries {
541            let redis::Value::Map(fields) = entry else {
542                continue;
543            };
544
545            let mut id = String::new();
546            let mut distance = 0.0;
547            let mut score_found = false;
548            let mut document_json = String::new();
549
550            for (k, v) in fields {
551                match Self::extract_string(k).as_deref() {
552                    Some("id") => {
553                        if let Some(s) = Self::extract_string(v) {
554                            id = s;
555                        }
556                    }
557                    Some("extra_attributes") => {
558                        if let redis::Value::Map(attrs) = v {
559                            for (ak, av) in attrs {
560                                match Self::extract_string(ak).as_deref() {
561                                    Some("__vector_score") => {
562                                        distance = Self::extract_distance(av)?;
563                                        score_found = true;
564                                    }
565                                    Some("document") if include_document => {
566                                        if let Some(s) = Self::extract_string(av) {
567                                            document_json = s;
568                                        }
569                                    }
570                                    _ => {}
571                                }
572                            }
573                        }
574                    }
575                    _ => {}
576                }
577            }
578
579            if !score_found {
580                tracing::warn!(
581                    target: "rig",
582                    id = %id,
583                    "__vector_score field missing from search result, defaulting to 0.0"
584                );
585            }
586
587            results.push((distance, id, document_json));
588        }
589
590        Ok(results)
591    }
592
593    /// Recursively flattens a Redis reply into its scalar string tokens, in order.
594    /// Used to inspect `FT.INFO` output without depending on its exact shape.
595    fn flatten_tokens(value: &redis::Value, out: &mut Vec<String>) {
596        match value {
597            redis::Value::Array(items) | redis::Value::Set(items) => {
598                for v in items {
599                    Self::flatten_tokens(v, out);
600                }
601            }
602            redis::Value::Map(pairs) => {
603                for (k, v) in pairs {
604                    Self::flatten_tokens(k, out);
605                    Self::flatten_tokens(v, out);
606                }
607            }
608            redis::Value::BulkString(bytes) => out.push(String::from_utf8_lossy(bytes).to_string()),
609            redis::Value::SimpleString(s) => out.push(s.clone()),
610            redis::Value::VerbatimString { text, .. } => out.push(text.clone()),
611            redis::Value::Int(i) => out.push(i.to_string()),
612            redis::Value::Double(d) => out.push(d.to_string()),
613            _ => {}
614        }
615    }
616
617    /// Builds and executes an FT.SEARCH KNN query.
618    async fn execute_search(
619        &self,
620        vector_bytes: Vec<u8>,
621        req: &VectorSearchRequest<Filter>,
622        include_document: bool,
623    ) -> Result<redis::Value, VectorStoreError> {
624        let mut con = self.connection_manager.clone();
625
626        let filter_str = req
627            .filter()
628            .as_ref()
629            .map(|f| f.clone().into_inner())
630            .unwrap_or_else(|| "*".to_string());
631
632        let knn_query = format!(
633            "{}=>[KNN {} @{} $vec AS __vector_score]",
634            filter_str,
635            req.samples(),
636            self.vector_field
637        );
638
639        let mut cmd = redis::cmd("FT.SEARCH");
640        cmd.arg(&self.index_name)
641            .arg(&knn_query)
642            .arg("PARAMS")
643            .arg(2)
644            .arg("vec")
645            .arg(vector_bytes)
646            .arg("SORTBY")
647            .arg("__vector_score")
648            .arg("RETURN");
649
650        if include_document {
651            cmd.arg(2).arg("__vector_score").arg("document");
652        } else {
653            cmd.arg(1).arg("__vector_score");
654        }
655
656        cmd.arg("DIALECT").arg(2);
657
658        // Always specify LIMIT to override RediSearch's default of 10 results.
659        cmd.arg("LIMIT").arg(0).arg(req.samples());
660
661        cmd.query_async(&mut con)
662            .await
663            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
664    }
665
666    /// Converts a JSON value to a string suitable for a flat Redis hash field.
667    ///
668    /// Strings are stored unquoted, numbers/booleans use their string form
669    /// (`1`/`0` for booleans). Null/array/object return `None`.
670    fn json_value_to_hash_field(value: &serde_json::Value) -> Option<String> {
671        match value {
672            serde_json::Value::String(s) => Some(s.clone()),
673            serde_json::Value::Number(n) => Some(n.to_string()),
674            serde_json::Value::Bool(b) => Some(if *b { "1".to_string() } else { "0".to_string() }),
675            serde_json::Value::Null
676            | serde_json::Value::Array(_)
677            | serde_json::Value::Object(_) => None,
678        }
679    }
680}
681
682impl<Model> InsertDocuments for RedisVectorStore<Model>
683where
684    Model: EmbeddingModel + Send + Sync,
685{
686    /// Inserts documents with their precomputed embeddings into Redis.
687    ///
688    /// Each embedding in [`OneOrMany<Embedding>`] produces a separate Redis hash
689    /// keyed by `{prefix}{uuid}`. All hashes for a document share the same serialized
690    /// JSON in the `document` field but have distinct `embedded_text` values.
691    async fn insert_documents<Doc: Serialize + Embed + Send>(
692        &self,
693        documents: Vec<(Doc, OneOrMany<Embedding>)>,
694    ) -> Result<(), VectorStoreError> {
695        let mut con = self.connection_manager.clone();
696        let mut pipe = redis::pipe();
697
698        for (document, embeddings) in &documents {
699            let json_value = serde_json::to_value(document)?;
700            let json_document = json_value.to_string();
701
702            // Extract configured metadata fields from the document JSON.
703            let metadata: Vec<(String, String)> = if self.metadata_fields.is_empty() {
704                Vec::new()
705            } else {
706                self.metadata_fields
707                    .iter()
708                    .filter_map(|field_name| {
709                        let value = json_value.get(field_name)?;
710                        match Self::json_value_to_hash_field(value) {
711                            Some(hash_value) => Some((field_name.clone(), hash_value)),
712                            None => {
713                                tracing::warn!(
714                                    target: "rig",
715                                    field = %field_name,
716                                    value_type = %value,
717                                    "Metadata field has unsupported type (null/array/object), skipping"
718                                );
719                                None
720                            }
721                        }
722                    })
723                    .collect()
724            };
725
726            for embedding in embeddings.iter() {
727                let id = if let Some(ref prefix) = self.key_prefix {
728                    format!("{}{}", prefix, uuid::Uuid::new_v4())
729                } else {
730                    uuid::Uuid::new_v4().to_string()
731                };
732                let embedding_bytes = Self::embedding_to_bytes(&embedding.vec);
733
734                let cmd = pipe
735                    .cmd("HSET")
736                    .arg(&id)
737                    .arg("document")
738                    .arg(json_document.as_bytes())
739                    .arg("embedded_text")
740                    .arg(embedding.document.as_bytes())
741                    .arg(&self.vector_field)
742                    .arg(embedding_bytes);
743
744                for (field_name, field_value) in &metadata {
745                    cmd.arg(field_name).arg(field_value.as_bytes());
746                }
747
748                cmd.ignore();
749            }
750        }
751
752        pipe.query_async::<()>(&mut con)
753            .await
754            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
755
756        tracing::debug!(
757            target: "rig",
758            index = %self.index_name,
759            count = documents.len(),
760            metadata_fields = ?self.metadata_fields,
761            "Inserted documents into Redis vector store"
762        );
763
764        Ok(())
765    }
766}
767
768impl<M> VectorStoreIndex for RedisVectorStore<M>
769where
770    M: EmbeddingModel + Send + Sync,
771{
772    type Filter = Filter;
773
774    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
775        &self,
776        req: VectorSearchRequest<Self::Filter>,
777    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
778        if req.samples() == 0 {
779            return Ok(Vec::new());
780        }
781        let vector_bytes = self.embed_query(req.query()).await?;
782
783        let response = self.execute_search(vector_bytes, &req, true).await?;
784        let mut results = Self::parse_search_response::<T>(response)?
785            .into_iter()
786            .map(|(distance, id, doc)| (self.distance_metric.score(distance), id, doc))
787            .collect::<Vec<_>>();
788
789        if let Some(threshold) = req.threshold() {
790            results.retain(|(score, _, _)| *score >= threshold);
791        }
792
793        tracing::debug!(
794            target: "rig",
795            index = %self.index_name,
796            query = %req.query(),
797            "Selected documents: {}",
798            results.iter().map(|(score, id, _)| format!("{id} ({score:.4})")).collect::<Vec<_>>().join(", ")
799        );
800
801        Ok(results)
802    }
803
804    async fn top_n_ids(
805        &self,
806        req: VectorSearchRequest<Self::Filter>,
807    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
808        if req.samples() == 0 {
809            return Ok(Vec::new());
810        }
811        let vector_bytes = self.embed_query(req.query()).await?;
812
813        let response = self.execute_search(vector_bytes, &req, false).await?;
814        let mut results = Self::parse_search_response_ids(response)?
815            .into_iter()
816            .map(|(distance, id)| (self.distance_metric.score(distance), id))
817            .collect::<Vec<_>>();
818
819        if let Some(threshold) = req.threshold() {
820            results.retain(|(score, _)| *score >= threshold);
821        }
822
823        tracing::debug!(
824            target: "rig",
825            index = %self.index_name,
826            query = %req.query(),
827            "Selected document IDs: {}",
828            results.iter().map(|(score, id)| format!("{id} ({score:.4})")).collect::<Vec<_>>().join(", ")
829        );
830
831        Ok(results)
832    }
833}
834
835impl<M> VectorStoreIndexDyn for RedisVectorStore<M>
836where
837    M: EmbeddingModel + Sync + Send,
838{
839    fn top_n<'a>(
840        &'a self,
841        req: VectorSearchRequest<CoreFilter<serde_json::Value>>,
842    ) -> WasmBoxedFuture<'a, TopNResults> {
843        Box::pin(async move {
844            let req = req.try_map_filter(Filter::try_from)?;
845            let results = <Self as VectorStoreIndex>::top_n::<serde_json::Value>(self, req).await?;
846            Ok(results)
847        })
848    }
849
850    fn top_n_ids<'a>(
851        &'a self,
852        req: VectorSearchRequest<CoreFilter<serde_json::Value>>,
853    ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
854        Box::pin(async move {
855            let req = req.try_map_filter(Filter::try_from)?;
856            let results = <Self as VectorStoreIndex>::top_n_ids(self, req).await?;
857            Ok(results)
858        })
859    }
860}
861
862/// Filters out reserved hash field names (`document`, `embedded_text`, and the
863/// vector field) from a configured metadata field list, emitting a warning for
864/// each removed name to prevent overwriting reserved hash fields.
865fn filter_reserved_metadata_fields(fields: Vec<String>, vector_field: &str) -> Vec<String> {
866    let reserved = ["document", "embedded_text", vector_field];
867    fields
868        .into_iter()
869        .filter(|f| {
870            if reserved.contains(&f.as_str()) {
871                tracing::warn!(
872                    target: "rig",
873                    field = %f,
874                    "Metadata field name conflicts with reserved hash field, skipping"
875                );
876                false
877            } else {
878                true
879            }
880        })
881        .collect()
882}
883
/// Vector distance metric used by the RediSearch index. It selects the formula
/// that turns a returned distance into a similarity score (higher = more similar).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DistanceMetric {
    /// Cosine distance, scored as `1 - distance`
    /// (1 = identical, -1 = opposite).
    #[default]
    Cosine,
    /// Squared Euclidean (L2) distance in `[0, inf)`, scored as
    /// `1 / (1 + distance)` (1 = identical, tending to 0 for distant vectors).
    L2,
    /// Inner-product distance. RediSearch reports `1 - inner_product`, so the
    /// score `1 - distance` equals the inner product (higher = more similar).
    InnerProduct,
}

impl DistanceMetric {
    /// Returns the value to pass as the `DISTANCE_METRIC` argument of `FT.CREATE`.
    fn as_arg(self) -> &'static str {
        match self {
            Self::InnerProduct => "IP",
            Self::L2 => "L2",
            Self::Cosine => "COSINE",
        }
    }

    /// Maps a RediSearch distance onto a similarity score where larger means
    /// closer. Every branch decreases as `distance` grows, so the
    /// nearest-first ordering produced by RediSearch is kept for all metrics.
    fn score(self, distance: f64) -> f64 {
        match self {
            Self::L2 => 1.0 / (1.0 + distance),
            Self::Cosine | Self::InnerProduct => 1.0 - distance,
        }
    }
}
919
/// The RediSearch field type given to a metadata field that is declared through
/// [`RedisVectorStore::create_index`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetadataFieldType {
    /// Tag field for exact matching (`@field:{value}`).
    Tag,
    /// Numeric field that can be filtered by range (`@field:[min max]`).
    Numeric,
    /// Full-text searchable field (`@field:(tokens)`).
    Text,
}

impl MetadataFieldType {
    /// Returns the RediSearch type keyword (`TAG`, `NUMERIC`, or `TEXT`) that
    /// corresponds to this field type.
    fn as_arg(self) -> &'static str {
        match self {
            Self::Text => "TEXT",
            Self::Numeric => "NUMERIC",
            Self::Tag => "TAG",
        }
    }
}
940
#[cfg(test)]
mod tests {
    use super::*;
    use rig_core::embeddings::embedding::EmbeddingError;

    /// Stand-in embedding model. It exists only so that `RedisVectorStore` can
    /// be named with a concrete type parameter, which lets the tests call the
    /// store's associated helper functions that take no `self`.
    struct FakeModel;

    impl EmbeddingModel for FakeModel {
        const MAX_DOCUMENTS: usize = 1024;
        type Client = ();

        fn make(_client: &Self::Client, _model: impl Into<String>, _dims: Option<usize>) -> Self {
            Self
        }

        fn ndims(&self) -> usize {
            3
        }

        async fn embed_texts(
            &self,
            _texts: impl IntoIterator<Item = String> + Send,
        ) -> Result<Vec<Embedding>, EmbeddingError> {
            Ok(vec![])
        }
    }

    type Store = RedisVectorStore<FakeModel>;

    /// Wraps a string as a Redis bulk-string value.
    fn bulk(s: &str) -> redis::Value {
        redis::Value::BulkString(s.as_bytes().to_owned())
    }

    /// Builds the flat field array of one RESP2 `FT.SEARCH` hit.
    fn resp2_fields(score: &str, document: &str) -> redis::Value {
        redis::Value::Array(vec![
            bulk("__vector_score"),
            bulk(score),
            bulk("document"),
            bulk(document),
        ])
    }

    /// Asserts that two floats agree to within `1e-9`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-9);
    }

    #[test]
    fn reserved_metadata_fields_are_filtered() {
        let configured: Vec<String> = [
            "category",
            "document",
            "embedded_text",
            "embedding",
            "price",
        ]
        .into_iter()
        .map(String::from)
        .collect();

        let kept = filter_reserved_metadata_fields(configured, "embedding");

        assert_eq!(kept, vec!["category".to_string(), "price".to_string()]);
    }

    #[test]
    fn json_value_to_hash_field_covers_all_types() {
        let cases = [
            (serde_json::json!("hello"), Some("hello")),
            (serde_json::json!(3), Some("3")),
            (serde_json::json!(true), Some("1")),
            (serde_json::json!(false), Some("0")),
            (serde_json::Value::Null, None),
            (serde_json::json!([1, 2]), None),
            (serde_json::json!({"a": 1}), None),
        ];

        for (input, expected) in cases {
            assert_eq!(
                Store::json_value_to_hash_field(&input),
                expected.map(str::to_string)
            );
        }
    }

    #[test]
    fn embedding_to_bytes_is_float32_le() {
        // 1.0_f32 in little-endian byte order is [0, 0, 128, 63].
        let expected: Vec<u8> = 1.0_f32.to_le_bytes().to_vec();
        let bytes = Store::embedding_to_bytes(&[1.0_f64]);
        assert_eq!(bytes, expected);
    }

    #[test]
    fn parse_search_response_skips_empty_documents() {
        // Reported count is 2: doc:1 carries valid JSON, doc:2 has an empty
        // `document` field.
        let response = redis::Value::Array(vec![
            redis::Value::Int(2),
            bulk("doc:1"),
            resp2_fields("0.1", "{\"a\":1}"),
            bulk("doc:2"),
            resp2_fields("0.2", ""),
        ]);

        let hits = Store::parse_search_response::<serde_json::Value>(response).expect("parse ok");

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].1, "doc:1");
        // The parser yields the raw distance; the metric conversion happens later.
        assert_close(hits[0].0, 0.1);
    }

    #[test]
    fn parse_search_response_empty_when_count_zero() {
        let response = redis::Value::Array(vec![redis::Value::Int(0)]);

        let hits = Store::parse_search_response::<serde_json::Value>(response).expect("parse ok");

        assert!(hits.is_empty());
    }

    #[test]
    fn parse_resp3_map_response() {
        // RESP3 FT.SEARCH reply shape: a top-level map whose "results" entry is
        // an array of per-document maps.
        let extra_attributes = redis::Value::Map(vec![
            (bulk("__vector_score"), bulk("0.1")),
            (bulk("document"), bulk("{\"a\":1}")),
        ]);
        let hit = redis::Value::Map(vec![
            (bulk("id"), bulk("d:1")),
            (bulk("extra_attributes"), extra_attributes),
        ]);
        let response = redis::Value::Map(vec![
            (bulk("attributes"), redis::Value::Array(vec![])),
            (bulk("format"), bulk("STRING")),
            (bulk("results"), redis::Value::Array(vec![hit])),
            (bulk("total_results"), redis::Value::Int(1)),
        ]);

        let hits = Store::parse_search_response::<serde_json::Value>(response).expect("parse ok");

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].1, "d:1");
        // Raw distance, not yet converted to a score.
        assert_close(hits[0].0, 0.1);
    }

    #[test]
    fn parse_resp3_map_empty_results() {
        let response = redis::Value::Map(vec![
            (bulk("results"), redis::Value::Array(vec![])),
            (bulk("total_results"), redis::Value::Int(0)),
        ]);

        let hits = Store::parse_search_response::<serde_json::Value>(response).expect("parse ok");

        assert!(hits.is_empty());
    }

    #[test]
    fn distance_metric_score_conversions() {
        // (metric, distance, expected score)
        let cases = [
            // Cosine: 1 - distance, range [-1, 1].
            (DistanceMetric::Cosine, 0.0, 1.0),
            (DistanceMetric::Cosine, 2.0, -1.0),
            // Inner product: 1 - distance (== the dot product).
            (DistanceMetric::InnerProduct, 0.0, 1.0),
            (DistanceMetric::InnerProduct, 0.5, 0.5),
            // L2: 1 / (1 + distance), range (0, 1].
            (DistanceMetric::L2, 0.0, 1.0),
            (DistanceMetric::L2, 3.0, 0.25),
        ];

        for (metric, distance, expected) in cases {
            assert_close(metric.score(distance), expected);
        }
    }

    #[test]
    fn distance_metric_score_is_monotonic_decreasing() {
        let metrics = [
            DistanceMetric::Cosine,
            DistanceMetric::L2,
            DistanceMetric::InnerProduct,
        ];

        for metric in metrics {
            let near = metric.score(0.1);
            let far = metric.score(0.5);
            assert!(
                near > far,
                "{metric:?} score must decrease as distance grows"
            );
        }
    }

    #[test]
    fn distance_metric_as_arg() {
        let expectations = [
            (DistanceMetric::Cosine, "COSINE"),
            (DistanceMetric::L2, "L2"),
            (DistanceMetric::InnerProduct, "IP"),
        ];

        for (metric, arg) in expectations {
            assert_eq!(metric.as_arg(), arg);
        }
    }
}