// redis_vl/extensions/cache.rs

//! Semantic and embedding cache extensions.
//!
//! [`EmbeddingsCache`](crate::EmbeddingsCache) provides deterministic cache
//! lookups for embedding vectors keyed by content and model name.
//! [`SemanticCache`](crate::SemanticCache) provides LLM response caching with
//! vector similarity lookup — when a new prompt is semantically similar to a
//! cached prompt (within a configurable distance threshold), the cached
//! response is returned.
//!
//! Both caches are Redis-backed and support sync and async operations.

12use std::{collections::HashMap, sync::Arc};
13
14use chrono::Utc;
15use redis::AsyncCommands;
16use serde::{Deserialize, Serialize};
17use serde_json::{Map, Number, Value, json};
18use sha2::{Digest, Sha256};
19
20use crate::{
21    error::Result,
22    filter::FilterExpression,
23    index::{AsyncSearchIndex, QueryOutput, RedisConnectionInfo, SearchIndex},
24    query::{Vector, VectorRangeQuery},
25    schema::VectorDataType,
26    vectorizers::Vectorizer,
27};
28
/// Record field holding the entry id; also the id field passed to `load`.
const SEMANTIC_ENTRY_ID_FIELD: &str = "entry_id";
/// Record field holding the original prompt text.
const SEMANTIC_PROMPT_FIELD: &str = "prompt";
/// Record field holding the cached response text.
const SEMANTIC_RESPONSE_FIELD: &str = "response";
/// Vector field searched by the semantic range query.
const SEMANTIC_VECTOR_FIELD: &str = "prompt_vector";
/// Record field holding the timestamp taken when the entry was stored.
const SEMANTIC_INSERTED_AT_FIELD: &str = "inserted_at";
/// Record field holding the last-update timestamp.
const SEMANTIC_UPDATED_AT_FIELD: &str = "updated_at";
/// Optional record field holding caller-supplied metadata.
const SEMANTIC_METADATA_FIELD: &str = "metadata";
/// Hit field read to find the Redis key whose TTL should be refreshed.
const SEMANTIC_KEY_FIELD: &str = "key";
37
38/// Shared configuration for cache-backed extensions.
39#[derive(Debug, Clone)]
40pub struct CacheConfig {
41    /// Cache name or key namespace.
42    pub name: String,
43    /// Redis connection settings.
44    pub connection: RedisConnectionInfo,
45    /// Optional TTL in seconds.
46    pub ttl_seconds: Option<u64>,
47}
48
49impl CacheConfig {
50    /// Creates a new cache configuration with no default TTL.
51    pub fn new(name: impl Into<String>, redis_url: impl Into<String>) -> Self {
52        Self {
53            name: name.into(),
54            connection: RedisConnectionInfo::new(redis_url),
55            ttl_seconds: None,
56        }
57    }
58
59    /// Adds a default TTL to the cache configuration.
60    #[must_use]
61    pub fn with_ttl(mut self, ttl_seconds: u64) -> Self {
62        self.ttl_seconds = Some(ttl_seconds);
63        self
64    }
65}
66
67impl Default for CacheConfig {
68    fn default() -> Self {
69        Self::new("embedcache", "redis://127.0.0.1:6379")
70    }
71}
72
73/// Entry stored in the embeddings cache.
74#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
75pub struct EmbeddingCacheEntry {
76    /// Deterministic entry identifier derived from content and model name.
77    pub entry_id: String,
78    /// Original content that was embedded.
79    pub content: String,
80    /// Embedding model name.
81    pub model_name: String,
82    /// Embedding vector payload.
83    pub embedding: Vec<f32>,
84    /// Optional arbitrary metadata stored alongside the vector.
85    #[serde(default, skip_serializing_if = "Option::is_none")]
86    pub metadata: Option<Value>,
87}
88
89/// Batch input item for the embeddings cache.
90#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
91pub struct EmbeddingCacheItem {
92    /// Original content that was embedded.
93    pub content: String,
94    /// Embedding model name.
95    pub model_name: String,
96    /// Embedding vector payload.
97    pub embedding: Vec<f32>,
98    /// Optional arbitrary metadata stored alongside the vector.
99    #[serde(default, skip_serializing_if = "Option::is_none")]
100    pub metadata: Option<Value>,
101}
102
103/// Semantic cache backed by a Redis Search vector index.
104#[derive(Clone)]
105pub struct SemanticCache {
106    /// Cache configuration.
107    pub config: CacheConfig,
108    /// Distance threshold used for semantic hits.
109    pub distance_threshold: f32,
110    /// Prompt embedding dimensions stored in Redis.
111    pub vector_dimensions: usize,
112    /// Vector element data type used for the index schema.
113    pub dtype: VectorDataType,
114    /// Underlying search index.
115    pub index: SearchIndex,
116    vectorizer: Option<Arc<dyn Vectorizer>>,
117    return_fields: Vec<String>,
118}
119
120impl SemanticCache {
121    /// Creates a new semantic cache with the default reserved schema.
122    pub fn new(
123        config: CacheConfig,
124        distance_threshold: f32,
125        vector_dimensions: usize,
126    ) -> Result<Self> {
127        Self::with_options(
128            config,
129            distance_threshold,
130            vector_dimensions,
131            VectorDataType::Float32,
132            &[],
133        )
134    }
135
136    /// Creates a new semantic cache with a specific vector data type.
137    pub fn with_dtype(
138        config: CacheConfig,
139        distance_threshold: f32,
140        vector_dimensions: usize,
141        dtype: VectorDataType,
142    ) -> Result<Self> {
143        Self::with_options(config, distance_threshold, vector_dimensions, dtype, &[])
144    }
145
146    /// Creates a new semantic cache with additional filterable schema fields.
147    pub fn with_filterable_fields(
148        config: CacheConfig,
149        distance_threshold: f32,
150        vector_dimensions: usize,
151        filterable_fields: &[Value],
152    ) -> Result<Self> {
153        Self::with_options(
154            config,
155            distance_threshold,
156            vector_dimensions,
157            VectorDataType::Float32,
158            filterable_fields,
159        )
160    }
161
162    /// Creates a new semantic cache with full control over dtype and filterable fields.
163    pub fn with_options(
164        config: CacheConfig,
165        distance_threshold: f32,
166        vector_dimensions: usize,
167        dtype: VectorDataType,
168        filterable_fields: &[Value],
169    ) -> Result<Self> {
170        validate_distance_threshold(distance_threshold)?;
171        if vector_dimensions == 0 {
172            return Err(crate::Error::InvalidInput(
173                "vector_dimensions must be greater than zero".to_owned(),
174            ));
175        }
176        validate_filterable_fields(filterable_fields)?;
177
178        let schema =
179            semantic_cache_schema(&config.name, vector_dimensions, dtype, filterable_fields);
180        let index = SearchIndex::from_json_value(schema, config.connection.redis_url.clone())?;
181        if !index.exists().unwrap_or(false) {
182            index.create_with_options(false, false)?;
183        }
184
185        Ok(Self {
186            config,
187            distance_threshold,
188            vector_dimensions,
189            dtype,
190            index,
191            vectorizer: None,
192            return_fields: default_semantic_return_fields(),
193        })
194    }
195
196    /// Attaches a synchronous vectorizer used when callers pass prompts instead of vectors.
197    #[must_use]
198    pub fn with_vectorizer<V>(mut self, vectorizer: V) -> Self
199    where
200        V: Vectorizer + 'static,
201    {
202        self.vectorizer = Some(Arc::new(vectorizer));
203        self
204    }
205
206    /// Attaches the default HuggingFace local vectorizer.
207    ///
208    /// This uses the `AllMiniLML6V2` model from [`fastembed`] which runs
209    /// locally via ONNX Runtime and requires no API key. The model is
210    /// downloaded from HuggingFace Hub on first use.
211    ///
212    /// # Errors
213    ///
214    /// Returns an error if the model cannot be loaded.
215    #[cfg(feature = "hf-local")]
216    pub fn with_default_vectorizer(self) -> Result<Self> {
217        let vectorizer = crate::vectorizers::HuggingFaceTextVectorizer::new(Default::default())?;
218        Ok(self.with_vectorizer(vectorizer))
219    }
220
221    /// Replaces the vectorizer used for prompt embedding.
222    pub fn set_vectorizer<V>(&mut self, vectorizer: V)
223    where
224        V: Vectorizer + 'static,
225    {
226        self.vectorizer = Some(Arc::new(vectorizer));
227    }
228
229    /// Returns the configured default TTL for semantic cache entries.
230    pub fn ttl(&self) -> Option<u64> {
231        self.config.ttl_seconds
232    }
233
234    /// Sets or clears the default TTL for semantic cache entries.
235    pub fn set_ttl(&mut self, ttl_seconds: Option<u64>) {
236        self.config.ttl_seconds = ttl_seconds;
237    }
238
239    /// Updates the semantic distance threshold.
240    pub fn set_threshold(&mut self, distance_threshold: f32) -> Result<()> {
241        validate_distance_threshold(distance_threshold)?;
242        self.distance_threshold = distance_threshold;
243        Ok(())
244    }
245
246    /// Stores a semantic cache entry and returns its Redis key.
247    pub fn store(
248        &self,
249        prompt: &str,
250        response: &str,
251        vector: Option<&[f32]>,
252        metadata: Option<Value>,
253        filters: Option<Map<String, Value>>,
254        ttl_seconds: Option<u64>,
255    ) -> Result<String> {
256        if let Some(metadata) = metadata.as_ref() {
257            validate_metadata(metadata)?;
258        }
259
260        let vector = self.resolve_vector(prompt, vector)?;
261        let timestamp = current_timestamp();
262        let entry_id = semantic_entry_id(prompt, filters.as_ref());
263        let mut record = Map::new();
264        record.insert(SEMANTIC_ENTRY_ID_FIELD.to_owned(), Value::String(entry_id));
265        record.insert(
266            SEMANTIC_PROMPT_FIELD.to_owned(),
267            Value::String(prompt.to_owned()),
268        );
269        record.insert(
270            SEMANTIC_RESPONSE_FIELD.to_owned(),
271            Value::String(response.to_owned()),
272        );
273        record.insert(
274            SEMANTIC_VECTOR_FIELD.to_owned(),
275            Value::Array(
276                vector
277                    .iter()
278                    .copied()
279                    .map(|value| number_value(f64::from(value)))
280                    .collect(),
281            ),
282        );
283        record.insert(
284            SEMANTIC_INSERTED_AT_FIELD.to_owned(),
285            number_value(timestamp),
286        );
287        record.insert(
288            SEMANTIC_UPDATED_AT_FIELD.to_owned(),
289            number_value(timestamp),
290        );
291        if let Some(metadata) = metadata {
292            record.insert(SEMANTIC_METADATA_FIELD.to_owned(), metadata);
293        }
294        if let Some(filters) = filters {
295            for (key, value) in filters {
296                record.insert(key, value);
297            }
298        }
299
300        let keys = self.index.load(
301            &[Value::Object(record)],
302            SEMANTIC_ENTRY_ID_FIELD,
303            ttl_seconds
304                .or(self.config.ttl_seconds)
305                .map(|value| value as i64),
306        )?;
307        Ok(keys.into_iter().next().unwrap_or_default())
308    }
309
310    /// Stores a semantic cache entry asynchronously and returns its Redis key.
311    pub async fn astore(
312        &self,
313        prompt: &str,
314        response: &str,
315        vector: Option<&[f32]>,
316        metadata: Option<Value>,
317        filters: Option<Map<String, Value>>,
318        ttl_seconds: Option<u64>,
319    ) -> Result<String> {
320        if let Some(metadata) = metadata.as_ref() {
321            validate_metadata(metadata)?;
322        }
323
324        let vector = self.resolve_vector(prompt, vector)?;
325        let timestamp = current_timestamp();
326        let entry_id = semantic_entry_id(prompt, filters.as_ref());
327        let mut record = Map::new();
328        record.insert(SEMANTIC_ENTRY_ID_FIELD.to_owned(), Value::String(entry_id));
329        record.insert(
330            SEMANTIC_PROMPT_FIELD.to_owned(),
331            Value::String(prompt.to_owned()),
332        );
333        record.insert(
334            SEMANTIC_RESPONSE_FIELD.to_owned(),
335            Value::String(response.to_owned()),
336        );
337        record.insert(
338            SEMANTIC_VECTOR_FIELD.to_owned(),
339            Value::Array(
340                vector
341                    .iter()
342                    .copied()
343                    .map(|value| number_value(f64::from(value)))
344                    .collect(),
345            ),
346        );
347        record.insert(
348            SEMANTIC_INSERTED_AT_FIELD.to_owned(),
349            number_value(timestamp),
350        );
351        record.insert(
352            SEMANTIC_UPDATED_AT_FIELD.to_owned(),
353            number_value(timestamp),
354        );
355        if let Some(metadata) = metadata {
356            record.insert(SEMANTIC_METADATA_FIELD.to_owned(), metadata);
357        }
358        if let Some(filters) = filters {
359            for (key, value) in filters {
360                record.insert(key, value);
361            }
362        }
363
364        let keys = self
365            .async_index()
366            .load(
367                &[Value::Object(record)],
368                SEMANTIC_ENTRY_ID_FIELD,
369                ttl_seconds
370                    .or(self.config.ttl_seconds)
371                    .map(|value| value as i64),
372            )
373            .await?;
374        Ok(keys.into_iter().next().unwrap_or_default())
375    }
376
377    /// Checks the semantic cache for similar prompts or a supplied vector.
378    pub fn check(
379        &self,
380        prompt: Option<&str>,
381        vector: Option<&[f32]>,
382        num_results: usize,
383        return_fields: Option<&[&str]>,
384        filter_expression: Option<FilterExpression>,
385        distance_threshold: Option<f32>,
386    ) -> Result<Vec<Map<String, Value>>> {
387        let vector = self.resolve_query_vector(prompt, vector)?;
388        let threshold = distance_threshold.unwrap_or(self.distance_threshold);
389        validate_distance_threshold(threshold)?;
390        let mut query = VectorRangeQuery::new(
391            Vector::new(vector.clone()),
392            SEMANTIC_VECTOR_FIELD,
393            threshold,
394        )
395        .paging(0, num_results)
396        .with_return_fields(self.return_fields.iter().map(String::as_str));
397        if let Some(filter_expression) = filter_expression {
398            query = query.with_filter(filter_expression);
399        }
400
401        let hits = process_semantic_hits(
402            query_output_documents(self.index.query(&query)?)?,
403            return_fields,
404        )?;
405        self.refresh_ttl_sync(&hits)?;
406        Ok(hits)
407    }
408
409    /// Checks the semantic cache for similar prompts or a supplied vector asynchronously.
410    pub async fn acheck(
411        &self,
412        prompt: Option<&str>,
413        vector: Option<&[f32]>,
414        num_results: usize,
415        return_fields: Option<&[&str]>,
416        filter_expression: Option<FilterExpression>,
417        distance_threshold: Option<f32>,
418    ) -> Result<Vec<Map<String, Value>>> {
419        let vector = self.resolve_query_vector(prompt, vector)?;
420        let threshold = distance_threshold.unwrap_or(self.distance_threshold);
421        validate_distance_threshold(threshold)?;
422        let mut query = VectorRangeQuery::new(
423            Vector::new(vector.clone()),
424            SEMANTIC_VECTOR_FIELD,
425            threshold,
426        )
427        .paging(0, num_results)
428        .with_return_fields(self.return_fields.iter().map(String::as_str));
429        if let Some(filter_expression) = filter_expression {
430            query = query.with_filter(filter_expression);
431        }
432
433        let hits = process_semantic_hits(
434            query_output_documents(self.async_index().query(&query).await?)?,
435            return_fields,
436        )?;
437        self.refresh_ttl_async(&hits).await?;
438        Ok(hits)
439    }
440
441    /// Updates cached fields for a stored semantic cache entry and refreshes TTL.
442    pub fn update(&self, key: &str, fields: Map<String, Value>) -> Result<()> {
443        let mapping = prepare_semantic_update_fields(fields)?;
444        let client = self.config.connection.client()?;
445        let mut connection = client.get_connection()?;
446        let mut cmd = redis::cmd("HSET");
447        cmd.arg(key);
448        for (field, value) in mapping {
449            cmd.arg(field).arg(value);
450        }
451        let _: usize = cmd.query(&mut connection)?;
452        self.expire_key(key, None)
453    }
454
455    /// Updates cached fields asynchronously for a stored semantic cache entry and refreshes TTL.
456    pub async fn aupdate(&self, key: &str, fields: Map<String, Value>) -> Result<()> {
457        let mapping = prepare_semantic_update_fields(fields)?;
458        let client = self.config.connection.client()?;
459        let mut connection = client.get_multiplexed_async_connection().await?;
460        let mut cmd = redis::cmd("HSET");
461        cmd.arg(key);
462        for (field, value) in mapping {
463            cmd.arg(field).arg(value);
464        }
465        let _: usize = cmd.query_async(&mut connection).await?;
466        self.aexpire_key(key, None).await
467    }
468
469    /// Clears all semantic cache entries while preserving the index.
470    pub fn clear(&self) -> Result<usize> {
471        self.index.clear()
472    }
473
474    /// Clears all semantic cache entries asynchronously while preserving the index.
475    pub async fn aclear(&self) -> Result<usize> {
476        self.async_index().clear().await
477    }
478
479    /// Deletes the semantic cache index and its documents.
480    pub fn delete(&self) -> Result<()> {
481        self.index.delete(true)
482    }
483
484    /// Deletes the semantic cache index asynchronously and its documents.
485    pub async fn adelete(&self) -> Result<()> {
486        self.async_index().delete(true).await
487    }
488
489    /// Drops stored entries by their entry ids.
490    pub fn drop_ids(&self, ids: &[String]) -> Result<()> {
491        let keys = ids.iter().map(|id| self.index.key(id)).collect::<Vec<_>>();
492        self.index.drop_keys(&keys)?;
493        Ok(())
494    }
495
496    /// Drops stored entries by their Redis keys.
497    pub fn drop_keys(&self, keys: &[String]) -> Result<()> {
498        self.index.drop_keys(keys)?;
499        Ok(())
500    }
501
502    /// Drops stored entries asynchronously by their entry ids.
503    pub async fn adrop_ids(&self, ids: &[String]) -> Result<()> {
504        let keys = ids.iter().map(|id| self.index.key(id)).collect::<Vec<_>>();
505        self.async_index().drop_keys(&keys).await?;
506        Ok(())
507    }
508
509    /// Drops stored entries asynchronously by their Redis keys.
510    pub async fn adrop_keys(&self, keys: &[String]) -> Result<()> {
511        self.async_index().drop_keys(keys).await?;
512        Ok(())
513    }
514
515    fn resolve_query_vector(
516        &self,
517        prompt: Option<&str>,
518        vector: Option<&[f32]>,
519    ) -> Result<Vec<f32>> {
520        match (prompt, vector) {
521            (_, Some(vector)) => self.validate_vector(vector),
522            (Some(prompt), None) => self.resolve_vector(prompt, None),
523            (None, None) => Err(crate::Error::InvalidInput(
524                "either prompt or vector must be specified".to_owned(),
525            )),
526        }
527    }
528
529    fn resolve_vector(&self, prompt: &str, vector: Option<&[f32]>) -> Result<Vec<f32>> {
530        match vector {
531            Some(vector) => self.validate_vector(vector),
532            None => {
533                let Some(vectorizer) = &self.vectorizer else {
534                    return Err(crate::Error::InvalidInput(
535                        "a vector or configured vectorizer is required".to_owned(),
536                    ));
537                };
538                let vector = vectorizer.embed(prompt)?;
539                self.validate_vector(&vector)
540            }
541        }
542    }
543
544    fn validate_vector(&self, vector: &[f32]) -> Result<Vec<f32>> {
545        if vector.len() != self.vector_dimensions {
546            return Err(crate::Error::InvalidInput(format!(
547                "vector dimensions mismatch: expected {}, got {}",
548                self.vector_dimensions,
549                vector.len()
550            )));
551        }
552        Ok(vector.to_vec())
553    }
554
555    fn async_index(&self) -> AsyncSearchIndex {
556        AsyncSearchIndex::new(
557            self.index.schema().clone(),
558            self.config.connection.redis_url.clone(),
559        )
560    }
561
562    fn refresh_ttl_sync(&self, hits: &[Map<String, Value>]) -> Result<()> {
563        if self.config.ttl_seconds.is_none() {
564            return Ok(());
565        }
566        for hit in hits {
567            if let Some(key) = hit.get(SEMANTIC_KEY_FIELD).and_then(Value::as_str) {
568                self.expire_key(key, None)?;
569            }
570        }
571        Ok(())
572    }
573
574    async fn refresh_ttl_async(&self, hits: &[Map<String, Value>]) -> Result<()> {
575        if self.config.ttl_seconds.is_none() {
576            return Ok(());
577        }
578        for hit in hits {
579            if let Some(key) = hit.get(SEMANTIC_KEY_FIELD).and_then(Value::as_str) {
580                self.aexpire_key(key, None).await?;
581            }
582        }
583        Ok(())
584    }
585
586    fn expire_key(&self, key: &str, ttl_override: Option<u64>) -> Result<()> {
587        if let Some(ttl_seconds) = ttl_override.or(self.config.ttl_seconds) {
588            let client = self.config.connection.client()?;
589            let mut connection = client.get_connection()?;
590            let _: bool = redis::cmd("EXPIRE")
591                .arg(key)
592                .arg(ttl_seconds)
593                .query(&mut connection)?;
594        }
595        Ok(())
596    }
597
598    async fn aexpire_key(&self, key: &str, ttl_override: Option<u64>) -> Result<()> {
599        if let Some(ttl_seconds) = ttl_override.or(self.config.ttl_seconds) {
600            let client = self.config.connection.client()?;
601            let mut connection = client.get_multiplexed_async_connection().await?;
602            let _: bool = redis::cmd("EXPIRE")
603                .arg(key)
604                .arg(ttl_seconds)
605                .query_async(&mut connection)
606                .await?;
607        }
608        Ok(())
609    }
610}
611
612impl std::fmt::Debug for SemanticCache {
613    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
614        formatter
615            .debug_struct("SemanticCache")
616            .field("config", &self.config)
617            .field("distance_threshold", &self.distance_threshold)
618            .field("vector_dimensions", &self.vector_dimensions)
619            .field("index_name", &self.index.name())
620            .finish()
621    }
622}
623
624/// Redis-backed cache for deterministic content/model embedding lookups.
625#[derive(Debug, Clone)]
626pub struct EmbeddingsCache {
627    /// Cache configuration.
628    pub config: CacheConfig,
629}
630
631impl Default for EmbeddingsCache {
632    fn default() -> Self {
633        Self::new(CacheConfig::default())
634    }
635}
636
637impl EmbeddingsCache {
638    /// Creates a new embeddings cache configuration.
639    pub fn new(config: CacheConfig) -> Self {
640        Self { config }
641    }
642
643    /// Generates the deterministic cache entry id for a content/model pair.
644    pub fn make_entry_id(&self, content: &str, model_name: &str) -> String {
645        hashify(&format!("{content}:{model_name}"))
646    }
647
648    /// Generates the full Redis key for a content/model pair.
649    pub fn make_cache_key(&self, content: &str, model_name: &str) -> String {
650        let entry_id = self.make_entry_id(content, model_name);
651        self.key_for_entry(&entry_id)
652    }
653
654    /// Retrieves a cached embedding by content and model name.
655    pub fn get(&self, content: &str, model_name: &str) -> Result<Option<EmbeddingCacheEntry>> {
656        let key = self.make_cache_key(content, model_name);
657        self.get_by_key(&key)
658    }
659
660    /// Retrieves a cached embedding by its Redis key.
661    pub fn get_by_key(&self, key: &str) -> Result<Option<EmbeddingCacheEntry>> {
662        let client = self.config.connection.client()?;
663        let mut connection = client.get_connection()?;
664        let data: HashMap<String, String> =
665            redis::cmd("HGETALL").arg(key).query(&mut connection)?;
666
667        if data.is_empty() {
668            return Ok(None);
669        }
670
671        self.expire_key(key, None)?;
672        parse_entry(data)
673    }
674
675    /// Retrieves multiple cached embeddings by content and model name.
676    pub fn mget<I, S>(
677        &self,
678        contents: I,
679        model_name: &str,
680    ) -> Result<Vec<Option<EmbeddingCacheEntry>>>
681    where
682        I: IntoIterator<Item = S>,
683        S: AsRef<str>,
684    {
685        let keys = contents
686            .into_iter()
687            .map(|content| self.make_cache_key(content.as_ref(), model_name))
688            .collect::<Vec<_>>();
689        self.mget_by_keys(keys)
690    }
691
692    /// Retrieves multiple cached embeddings by Redis key.
693    pub fn mget_by_keys<I, S>(&self, keys: I) -> Result<Vec<Option<EmbeddingCacheEntry>>>
694    where
695        I: IntoIterator<Item = S>,
696        S: AsRef<str>,
697    {
698        let keys = collect_strings(keys);
699        if keys.is_empty() {
700            return Ok(Vec::new());
701        }
702
703        let mut results = Vec::with_capacity(keys.len());
704        for key in &keys {
705            results.push(self.get_by_key(key)?);
706        }
707        Ok(results)
708    }
709
710    /// Stores a cached embedding and returns its Redis key.
711    pub fn set(
712        &self,
713        content: &str,
714        model_name: &str,
715        embedding: &[f32],
716        metadata: Option<Value>,
717        ttl_seconds: Option<u64>,
718    ) -> Result<String> {
719        let entry = self.prepare_entry(content, model_name, embedding, metadata);
720        let key = self.key_for_entry(&entry.entry_id);
721        self.write_entry(&key, &entry)?;
722        self.expire_key(&key, ttl_seconds)?;
723        Ok(key)
724    }
725
726    /// Stores multiple cached embeddings and returns their Redis keys.
727    pub fn mset(
728        &self,
729        items: &[EmbeddingCacheItem],
730        ttl_seconds: Option<u64>,
731    ) -> Result<Vec<String>> {
732        let mut keys = Vec::with_capacity(items.len());
733        for item in items {
734            let key = self.set(
735                &item.content,
736                &item.model_name,
737                &item.embedding,
738                item.metadata.clone(),
739                ttl_seconds,
740            )?;
741            keys.push(key);
742        }
743        Ok(keys)
744    }
745
746    /// Checks whether a cached embedding exists for a content/model pair.
747    pub fn exists(&self, content: &str, model_name: &str) -> Result<bool> {
748        let key = self.make_cache_key(content, model_name);
749        self.exists_by_key(&key)
750    }
751
752    /// Checks whether a cached embedding exists for a Redis key.
753    pub fn exists_by_key(&self, key: &str) -> Result<bool> {
754        let client = self.config.connection.client()?;
755        let mut connection = client.get_connection()?;
756        let exists: u64 = redis::cmd("EXISTS").arg(key).query(&mut connection)?;
757        Ok(exists > 0)
758    }
759
760    /// Checks whether multiple cached embeddings exist for content/model pairs.
761    pub fn mexists<I, S>(&self, contents: I, model_name: &str) -> Result<Vec<bool>>
762    where
763        I: IntoIterator<Item = S>,
764        S: AsRef<str>,
765    {
766        let keys = contents
767            .into_iter()
768            .map(|content| self.make_cache_key(content.as_ref(), model_name))
769            .collect::<Vec<_>>();
770        self.mexists_by_keys(keys)
771    }
772
773    /// Checks whether multiple cached embeddings exist for Redis keys.
774    pub fn mexists_by_keys<I, S>(&self, keys: I) -> Result<Vec<bool>>
775    where
776        I: IntoIterator<Item = S>,
777        S: AsRef<str>,
778    {
779        let keys = collect_strings(keys);
780        if keys.is_empty() {
781            return Ok(Vec::new());
782        }
783
784        let client = self.config.connection.client()?;
785        let mut connection = client.get_connection()?;
786        let mut results = Vec::with_capacity(keys.len());
787        for key in keys {
788            let exists: u64 = redis::cmd("EXISTS").arg(key).query(&mut connection)?;
789            results.push(exists > 0);
790        }
791        Ok(results)
792    }
793
794    /// Removes a cached embedding by content and model name.
795    pub fn drop(&self, content: &str, model_name: &str) -> Result<()> {
796        let key = self.make_cache_key(content, model_name);
797        self.drop_by_key(&key)
798    }
799
800    /// Removes a cached embedding by Redis key.
801    pub fn drop_by_key(&self, key: &str) -> Result<()> {
802        let client = self.config.connection.client()?;
803        let mut connection = client.get_connection()?;
804        let _: usize = redis::cmd("DEL").arg(key).query(&mut connection)?;
805        Ok(())
806    }
807
808    /// Removes multiple cached embeddings by content and model name.
809    pub fn mdrop<I, S>(&self, contents: I, model_name: &str) -> Result<()>
810    where
811        I: IntoIterator<Item = S>,
812        S: AsRef<str>,
813    {
814        let keys = contents
815            .into_iter()
816            .map(|content| self.make_cache_key(content.as_ref(), model_name))
817            .collect::<Vec<_>>();
818        self.mdrop_by_keys(keys)
819    }
820
821    /// Removes multiple cached embeddings by Redis key.
822    pub fn mdrop_by_keys<I, S>(&self, keys: I) -> Result<()>
823    where
824        I: IntoIterator<Item = S>,
825        S: AsRef<str>,
826    {
827        let keys = collect_strings(keys);
828        if keys.is_empty() {
829            return Ok(());
830        }
831
832        let client = self.config.connection.client()?;
833        let mut connection = client.get_connection()?;
834        let _: usize = redis::cmd("DEL").arg(keys).query(&mut connection)?;
835        Ok(())
836    }
837
838    /// Clears every cache entry under this cache namespace.
839    pub fn clear(&self) -> Result<usize> {
840        let keys = self.all_keys()?;
841        if keys.is_empty() {
842            return Ok(0);
843        }
844
845        let count = keys.len();
846        self.mdrop_by_keys(keys)?;
847        Ok(count)
848    }
849
850    /// Retrieves a cached embedding by content and model name asynchronously.
851    pub async fn aget(
852        &self,
853        content: &str,
854        model_name: &str,
855    ) -> Result<Option<EmbeddingCacheEntry>> {
856        let key = self.make_cache_key(content, model_name);
857        self.aget_by_key(&key).await
858    }
859
860    /// Retrieves a cached embedding by its Redis key asynchronously.
861    pub async fn aget_by_key(&self, key: &str) -> Result<Option<EmbeddingCacheEntry>> {
862        let client = self.config.connection.client()?;
863        let mut connection = client.get_multiplexed_async_connection().await?;
864        let data: HashMap<String, String> = redis::cmd("HGETALL")
865            .arg(key)
866            .query_async(&mut connection)
867            .await?;
868
869        if data.is_empty() {
870            return Ok(None);
871        }
872
873        self.aexpire_key(key, None).await?;
874        parse_entry(data)
875    }
876
877    /// Retrieves multiple cached embeddings by content and model name asynchronously.
878    pub async fn amget<I, S>(
879        &self,
880        contents: I,
881        model_name: &str,
882    ) -> Result<Vec<Option<EmbeddingCacheEntry>>>
883    where
884        I: IntoIterator<Item = S>,
885        S: AsRef<str>,
886    {
887        let keys = contents
888            .into_iter()
889            .map(|content| self.make_cache_key(content.as_ref(), model_name))
890            .collect::<Vec<_>>();
891        self.amget_by_keys(keys).await
892    }
893
894    /// Retrieves multiple cached embeddings by Redis key asynchronously.
895    pub async fn amget_by_keys<I, S>(&self, keys: I) -> Result<Vec<Option<EmbeddingCacheEntry>>>
896    where
897        I: IntoIterator<Item = S>,
898        S: AsRef<str>,
899    {
900        let keys = collect_strings(keys);
901        if keys.is_empty() {
902            return Ok(Vec::new());
903        }
904
905        let mut results = Vec::with_capacity(keys.len());
906        for key in &keys {
907            results.push(self.aget_by_key(key).await?);
908        }
909        Ok(results)
910    }
911
912    /// Stores a cached embedding asynchronously and returns its Redis key.
913    pub async fn aset(
914        &self,
915        content: &str,
916        model_name: &str,
917        embedding: &[f32],
918        metadata: Option<Value>,
919        ttl_seconds: Option<u64>,
920    ) -> Result<String> {
921        let entry = self.prepare_entry(content, model_name, embedding, metadata);
922        let key = self.key_for_entry(&entry.entry_id);
923        self.awrite_entry(&key, &entry).await?;
924        self.aexpire_key(&key, ttl_seconds).await?;
925        Ok(key)
926    }
927
928    /// Stores multiple cached embeddings asynchronously and returns their Redis keys.
929    pub async fn amset(
930        &self,
931        items: &[EmbeddingCacheItem],
932        ttl_seconds: Option<u64>,
933    ) -> Result<Vec<String>> {
934        let mut keys = Vec::with_capacity(items.len());
935        for item in items {
936            let key = self
937                .aset(
938                    &item.content,
939                    &item.model_name,
940                    &item.embedding,
941                    item.metadata.clone(),
942                    ttl_seconds,
943                )
944                .await?;
945            keys.push(key);
946        }
947        Ok(keys)
948    }
949
950    /// Checks whether a cached embedding exists for a content/model pair asynchronously.
951    pub async fn aexists(&self, content: &str, model_name: &str) -> Result<bool> {
952        let key = self.make_cache_key(content, model_name);
953        self.aexists_by_key(&key).await
954    }
955
956    /// Checks whether a cached embedding exists for a Redis key asynchronously.
957    pub async fn aexists_by_key(&self, key: &str) -> Result<bool> {
958        let client = self.config.connection.client()?;
959        let mut connection = client.get_multiplexed_async_connection().await?;
960        Ok(connection.exists(key).await?)
961    }
962
963    /// Checks whether multiple cached embeddings exist for content/model pairs asynchronously.
964    pub async fn amexists<I, S>(&self, contents: I, model_name: &str) -> Result<Vec<bool>>
965    where
966        I: IntoIterator<Item = S>,
967        S: AsRef<str>,
968    {
969        let keys = contents
970            .into_iter()
971            .map(|content| self.make_cache_key(content.as_ref(), model_name))
972            .collect::<Vec<_>>();
973        self.amexists_by_keys(keys).await
974    }
975
976    /// Checks whether multiple cached embeddings exist for Redis keys asynchronously.
977    pub async fn amexists_by_keys<I, S>(&self, keys: I) -> Result<Vec<bool>>
978    where
979        I: IntoIterator<Item = S>,
980        S: AsRef<str>,
981    {
982        let keys = collect_strings(keys);
983        if keys.is_empty() {
984            return Ok(Vec::new());
985        }
986
987        let client = self.config.connection.client()?;
988        let mut connection = client.get_multiplexed_async_connection().await?;
989        let mut results = Vec::with_capacity(keys.len());
990        for key in keys {
991            results.push(connection.exists(key).await?);
992        }
993        Ok(results)
994    }
995
996    /// Removes a cached embedding by content and model name asynchronously.
997    pub async fn adrop(&self, content: &str, model_name: &str) -> Result<()> {
998        let key = self.make_cache_key(content, model_name);
999        self.adrop_by_key(&key).await
1000    }
1001
1002    /// Removes a cached embedding by Redis key asynchronously.
1003    pub async fn adrop_by_key(&self, key: &str) -> Result<()> {
1004        let client = self.config.connection.client()?;
1005        let mut connection = client.get_multiplexed_async_connection().await?;
1006        let _: usize = connection.del(key).await?;
1007        Ok(())
1008    }
1009
1010    /// Removes multiple cached embeddings by content and model name asynchronously.
1011    pub async fn amdrop<I, S>(&self, contents: I, model_name: &str) -> Result<()>
1012    where
1013        I: IntoIterator<Item = S>,
1014        S: AsRef<str>,
1015    {
1016        let keys = contents
1017            .into_iter()
1018            .map(|content| self.make_cache_key(content.as_ref(), model_name))
1019            .collect::<Vec<_>>();
1020        self.amdrop_by_keys(keys).await
1021    }
1022
1023    /// Removes multiple cached embeddings by Redis key asynchronously.
1024    pub async fn amdrop_by_keys<I, S>(&self, keys: I) -> Result<()>
1025    where
1026        I: IntoIterator<Item = S>,
1027        S: AsRef<str>,
1028    {
1029        let keys = collect_strings(keys);
1030        if keys.is_empty() {
1031            return Ok(());
1032        }
1033
1034        let client = self.config.connection.client()?;
1035        let mut connection = client.get_multiplexed_async_connection().await?;
1036        let _: usize = connection.del(keys).await?;
1037        Ok(())
1038    }
1039
1040    /// Clears every cache entry under this cache namespace asynchronously.
1041    pub async fn aclear(&self) -> Result<usize> {
1042        let keys = self.aall_keys().await?;
1043        if keys.is_empty() {
1044            return Ok(0);
1045        }
1046
1047        let count = keys.len();
1048        self.amdrop_by_keys(keys).await?;
1049        Ok(count)
1050    }
1051
1052    fn prepare_entry(
1053        &self,
1054        content: &str,
1055        model_name: &str,
1056        embedding: &[f32],
1057        metadata: Option<Value>,
1058    ) -> EmbeddingCacheEntry {
1059        EmbeddingCacheEntry {
1060            entry_id: self.make_entry_id(content, model_name),
1061            content: content.to_owned(),
1062            model_name: model_name.to_owned(),
1063            embedding: embedding.to_vec(),
1064            metadata,
1065        }
1066    }
1067
1068    fn write_entry(&self, key: &str, entry: &EmbeddingCacheEntry) -> Result<()> {
1069        let client = self.config.connection.client()?;
1070        let mut connection = client.get_connection()?;
1071        let mut cmd = redis::cmd("HSET");
1072        cmd.arg(key)
1073            .arg("entry_id")
1074            .arg(&entry.entry_id)
1075            .arg("content")
1076            .arg(&entry.content)
1077            .arg("model_name")
1078            .arg(&entry.model_name)
1079            .arg("embedding")
1080            .arg(serde_json::to_string(&entry.embedding)?);
1081
1082        if let Some(metadata) = &entry.metadata {
1083            cmd.arg("metadata").arg(serde_json::to_string(metadata)?);
1084        }
1085
1086        let _: usize = cmd.query(&mut connection)?;
1087        Ok(())
1088    }
1089
1090    async fn awrite_entry(&self, key: &str, entry: &EmbeddingCacheEntry) -> Result<()> {
1091        let client = self.config.connection.client()?;
1092        let mut connection = client.get_multiplexed_async_connection().await?;
1093        let mut cmd = redis::cmd("HSET");
1094        cmd.arg(key)
1095            .arg("entry_id")
1096            .arg(&entry.entry_id)
1097            .arg("content")
1098            .arg(&entry.content)
1099            .arg("model_name")
1100            .arg(&entry.model_name)
1101            .arg("embedding")
1102            .arg(serde_json::to_string(&entry.embedding)?);
1103
1104        if let Some(metadata) = &entry.metadata {
1105            cmd.arg("metadata").arg(serde_json::to_string(metadata)?);
1106        }
1107
1108        let _: usize = cmd.query_async(&mut connection).await?;
1109        Ok(())
1110    }
1111
1112    fn expire_key(&self, key: &str, ttl_override: Option<u64>) -> Result<()> {
1113        if let Some(ttl_seconds) = ttl_override.or(self.config.ttl_seconds) {
1114            let client = self.config.connection.client()?;
1115            let mut connection = client.get_connection()?;
1116            let _: bool = redis::cmd("EXPIRE")
1117                .arg(key)
1118                .arg(ttl_seconds)
1119                .query(&mut connection)?;
1120        }
1121        Ok(())
1122    }
1123
1124    async fn aexpire_key(&self, key: &str, ttl_override: Option<u64>) -> Result<()> {
1125        if let Some(ttl_seconds) = ttl_override.or(self.config.ttl_seconds) {
1126            let client = self.config.connection.client()?;
1127            let mut connection = client.get_multiplexed_async_connection().await?;
1128            let _: bool = redis::cmd("EXPIRE")
1129                .arg(key)
1130                .arg(ttl_seconds)
1131                .query_async(&mut connection)
1132                .await?;
1133        }
1134        Ok(())
1135    }
1136
1137    fn all_keys(&self) -> Result<Vec<String>> {
1138        let client = self.config.connection.client()?;
1139        let mut connection = client.get_connection()?;
1140        let keys: Vec<String> = redis::cmd("KEYS")
1141            .arg(format!("{}:*", self.config.name))
1142            .query(&mut connection)?;
1143        Ok(keys)
1144    }
1145
1146    async fn aall_keys(&self) -> Result<Vec<String>> {
1147        let client = self.config.connection.client()?;
1148        let mut connection = client.get_multiplexed_async_connection().await?;
1149        let keys: Vec<String> = redis::cmd("KEYS")
1150            .arg(format!("{}:*", self.config.name))
1151            .query_async(&mut connection)
1152            .await?;
1153        Ok(keys)
1154    }
1155
1156    fn key_for_entry(&self, entry_id: &str) -> String {
1157        format!("{}:{entry_id}", self.config.name)
1158    }
1159}
1160
/// Copies each item of `values` into an owned `String`, preserving order.
fn collect_strings<I, S>(values: I) -> Vec<String>
where
    I: IntoIterator<Item = S>,
    S: AsRef<str>,
{
    let items = values.into_iter();
    let mut owned = Vec::with_capacity(items.size_hint().0);
    for item in items {
        owned.push(String::from(item.as_ref()));
    }
    owned
}
1171
1172fn parse_entry(data: HashMap<String, String>) -> Result<Option<EmbeddingCacheEntry>> {
1173    if data.is_empty() {
1174        return Ok(None);
1175    }
1176
1177    let entry = EmbeddingCacheEntry {
1178        entry_id: data.get("entry_id").cloned().unwrap_or_default(),
1179        content: data.get("content").cloned().unwrap_or_default(),
1180        model_name: data.get("model_name").cloned().unwrap_or_default(),
1181        embedding: match data.get("embedding") {
1182            Some(value) => serde_json::from_str::<Vec<f32>>(value)?,
1183            None => Vec::new(),
1184        },
1185        metadata: data
1186            .get("metadata")
1187            .map(|value| serde_json::from_str::<Value>(value))
1188            .transpose()?,
1189    };
1190
1191    Ok(Some(entry))
1192}
1193
1194fn hashify(content: &str) -> String {
1195    let mut hasher = Sha256::new();
1196    hasher.update(content.as_bytes());
1197    let digest = hasher.finalize();
1198    let mut output = String::with_capacity(digest.len() * 2);
1199    for byte in digest {
1200        use std::fmt::Write as _;
1201        let _ = write!(&mut output, "{byte:02x}");
1202    }
1203    output
1204}
1205
1206fn semantic_cache_schema(
1207    name: &str,
1208    vector_dimensions: usize,
1209    dtype: VectorDataType,
1210    filterable_fields: &[Value],
1211) -> Value {
1212    let mut fields = vec![
1213        json!({ "name": SEMANTIC_ENTRY_ID_FIELD, "type": "tag" }),
1214        json!({ "name": SEMANTIC_PROMPT_FIELD, "type": "text" }),
1215        json!({ "name": SEMANTIC_RESPONSE_FIELD, "type": "text" }),
1216        json!({ "name": SEMANTIC_INSERTED_AT_FIELD, "type": "numeric" }),
1217        json!({ "name": SEMANTIC_UPDATED_AT_FIELD, "type": "numeric" }),
1218        json!({ "name": SEMANTIC_METADATA_FIELD, "type": "text" }),
1219        json!({
1220            "name": SEMANTIC_VECTOR_FIELD,
1221            "type": "vector",
1222            "attrs": {
1223                "algorithm": "flat",
1224                "dims": vector_dimensions,
1225                "datatype": dtype.as_str(),
1226                "distance_metric": "cosine"
1227            }
1228        }),
1229    ];
1230    fields.extend(filterable_fields.iter().cloned());
1231    json!({
1232        "index": {
1233            "name": name,
1234            "prefix": name,
1235            "storage_type": "hash",
1236        },
1237        "fields": fields,
1238    })
1239}
1240
1241fn default_semantic_return_fields() -> Vec<String> {
1242    vec![
1243        SEMANTIC_ENTRY_ID_FIELD.to_owned(),
1244        SEMANTIC_PROMPT_FIELD.to_owned(),
1245        SEMANTIC_RESPONSE_FIELD.to_owned(),
1246        "vector_distance".to_owned(),
1247        SEMANTIC_INSERTED_AT_FIELD.to_owned(),
1248        SEMANTIC_UPDATED_AT_FIELD.to_owned(),
1249        SEMANTIC_METADATA_FIELD.to_owned(),
1250    ]
1251}
1252
1253fn current_timestamp() -> f64 {
1254    Utc::now().timestamp_millis() as f64 / 1000.0
1255}
1256
1257fn semantic_entry_id(prompt: &str, filters: Option<&Map<String, Value>>) -> String {
1258    if let Some(filters) = filters {
1259        let mut parts = filters
1260            .iter()
1261            .map(|(key, value)| format!("{key}{}", value_to_hash_string(value)))
1262            .collect::<Vec<_>>();
1263        parts.sort();
1264        hashify(&format!("{prompt}{}", parts.join("")))
1265    } else {
1266        hashify(prompt)
1267    }
1268}
1269
1270fn value_to_hash_string(value: &Value) -> String {
1271    match value {
1272        Value::Null => "null".to_owned(),
1273        Value::Bool(value) => value.to_string(),
1274        Value::Number(value) => value.to_string(),
1275        Value::String(value) => value.clone(),
1276        Value::Array(_) | Value::Object(_) => serde_json::to_string(value).unwrap_or_default(),
1277    }
1278}
1279
1280/// Reserved field names that cannot be used as filterable field names.
1281const RESERVED_SEMANTIC_FIELDS: &[&str] = &[
1282    SEMANTIC_ENTRY_ID_FIELD,
1283    SEMANTIC_PROMPT_FIELD,
1284    SEMANTIC_RESPONSE_FIELD,
1285    SEMANTIC_VECTOR_FIELD,
1286    SEMANTIC_INSERTED_AT_FIELD,
1287    SEMANTIC_UPDATED_AT_FIELD,
1288    SEMANTIC_METADATA_FIELD,
1289    SEMANTIC_KEY_FIELD,
1290    "vector_distance",
1291];
1292
1293fn validate_filterable_fields(fields: &[Value]) -> Result<()> {
1294    let mut seen = std::collections::HashSet::new();
1295    for field in fields {
1296        let name = field
1297            .get("name")
1298            .and_then(Value::as_str)
1299            .unwrap_or_default();
1300        let field_type = field
1301            .get("type")
1302            .and_then(Value::as_str)
1303            .unwrap_or_default();
1304
1305        if name.is_empty() {
1306            return Err(crate::Error::InvalidInput(
1307                "filterable field must have a non-empty 'name'".to_owned(),
1308            ));
1309        }
1310
1311        if RESERVED_SEMANTIC_FIELDS.contains(&name) {
1312            return Err(crate::Error::InvalidInput(format!(
1313                "{name} is a reserved field name for the semantic cache schema"
1314            )));
1315        }
1316
1317        if !seen.insert(name.to_owned()) {
1318            return Err(crate::Error::InvalidInput(format!(
1319                "duplicate field name: {name}. Field names must be unique"
1320            )));
1321        }
1322
1323        if !matches!(field_type, "tag" | "text" | "numeric" | "geo") {
1324            return Err(crate::Error::InvalidInput(format!(
1325                "invalid filterable field type: '{field_type}' for field '{name}'"
1326            )));
1327        }
1328    }
1329    Ok(())
1330}
1331
1332fn validate_distance_threshold(distance_threshold: f32) -> Result<()> {
1333    if !(0.0..=2.0).contains(&distance_threshold) {
1334        return Err(crate::Error::InvalidInput(format!(
1335            "distance threshold must be between 0 and 2, got {distance_threshold}"
1336        )));
1337    }
1338    Ok(())
1339}
1340
1341fn validate_metadata(metadata: &Value) -> Result<()> {
1342    if !metadata.is_object() {
1343        return Err(crate::Error::InvalidInput(
1344            "metadata must be a JSON object".to_owned(),
1345        ));
1346    }
1347    Ok(())
1348}
1349
1350fn query_output_documents(output: QueryOutput) -> Result<Vec<Map<String, Value>>> {
1351    match output {
1352        QueryOutput::Documents(documents) => Ok(documents),
1353        QueryOutput::Count(_) => Err(crate::Error::InvalidInput(
1354            "semantic cache queries must return documents".to_owned(),
1355        )),
1356    }
1357}
1358
1359fn process_semantic_hits(
1360    documents: Vec<Map<String, Value>>,
1361    return_fields: Option<&[&str]>,
1362) -> Result<Vec<Map<String, Value>>> {
1363    let selected = return_fields.map(|fields| {
1364        fields
1365            .iter()
1366            .map(|field| (*field).to_owned())
1367            .collect::<std::collections::HashSet<_>>()
1368    });
1369    let mut hits = Vec::with_capacity(documents.len());
1370    for mut document in documents {
1371        let key = document
1372            .remove("id")
1373            .unwrap_or_else(|| Value::String(String::new()));
1374        let mut hit = Map::new();
1375        hit.insert(SEMANTIC_KEY_FIELD.to_owned(), key);
1376        for (field, value) in document {
1377            let include = selected
1378                .as_ref()
1379                .is_none_or(|fields| fields.contains(&field));
1380            if !include {
1381                continue;
1382            }
1383            hit.insert(field.clone(), normalize_semantic_value(&field, value)?);
1384        }
1385        hits.push(hit);
1386    }
1387    Ok(hits)
1388}
1389
1390fn normalize_semantic_value(field: &str, value: Value) -> Result<Value> {
1391    match (field, value) {
1392        (SEMANTIC_METADATA_FIELD, Value::String(value)) => {
1393            Ok(serde_json::from_str(&value).unwrap_or(Value::String(value)))
1394        }
1395        (
1396            "vector_distance" | SEMANTIC_INSERTED_AT_FIELD | SEMANTIC_UPDATED_AT_FIELD,
1397            Value::String(value),
1398        ) => {
1399            let parsed = value.parse::<f64>().map_err(|_| {
1400                crate::Error::InvalidInput(format!("could not parse numeric field '{field}'"))
1401            })?;
1402            Ok(number_value(parsed))
1403        }
1404        (_, value) => Ok(value),
1405    }
1406}
1407
1408fn prepare_semantic_update_fields(fields: Map<String, Value>) -> Result<Vec<(String, String)>> {
1409    let mut mapping = Vec::with_capacity(fields.len() + 1);
1410    for (field, value) in fields {
1411        if field == SEMANTIC_VECTOR_FIELD {
1412            return Err(crate::Error::InvalidInput(
1413                "updating the stored vector is not supported yet".to_owned(),
1414            ));
1415        }
1416        if field == SEMANTIC_METADATA_FIELD {
1417            validate_metadata(&value)?;
1418        }
1419        let serialized = match value {
1420            Value::Null => "null".to_owned(),
1421            Value::Bool(value) => value.to_string(),
1422            Value::Number(value) => value.to_string(),
1423            Value::String(value) => value,
1424            Value::Array(_) | Value::Object(_) => serde_json::to_string(&value)?,
1425        };
1426        mapping.push((field, serialized));
1427    }
1428    mapping.push((
1429        SEMANTIC_UPDATED_AT_FIELD.to_owned(),
1430        current_timestamp().to_string(),
1431    ));
1432    Ok(mapping)
1433}
1434
1435fn number_value(value: f64) -> Value {
1436    Number::from_f64(value)
1437        .map(Value::Number)
1438        .unwrap_or(Value::Null)
1439}
1440
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::{
        CacheConfig, EmbeddingsCache, VectorDataType, hashify, semantic_cache_schema,
        validate_distance_threshold, validate_filterable_fields, validate_metadata,
    };

    const ADA_MODEL: &str = "text-embedding-ada-002";

    #[test]
    fn hashify_matches_expected_sha256() {
        let digest = hashify("Hello world:text-embedding-ada-002");
        assert_eq!(
            digest,
            "368dacc611e96e4189a9809faaca1a70b3c3306352bbcfc9ab6291359a5dfca0"
        );
    }

    #[test]
    fn cache_key_is_stable() {
        let cache = EmbeddingsCache::new(CacheConfig::default());
        assert_eq!(
            cache.make_cache_key("Hello world", ADA_MODEL),
            "embedcache:368dacc611e96e4189a9809faaca1a70b3c3306352bbcfc9ab6291359a5dfca0"
        );
    }

    #[test]
    fn entry_id_is_deterministic() {
        let cache = EmbeddingsCache::new(CacheConfig::default());
        let first = cache.make_entry_id("Hello world", ADA_MODEL);
        let second = cache.make_entry_id("Hello world", ADA_MODEL);
        assert_eq!(first, second);
        assert_ne!(first, cache.make_entry_id("Different text", ADA_MODEL));
    }

    #[test]
    fn entry_id_different_inputs_differ() {
        let cache = EmbeddingsCache::new(CacheConfig::default());
        assert_ne!(
            cache.make_entry_id("What is machine learning?", ADA_MODEL),
            cache.make_entry_id("How do neural networks work?", ADA_MODEL)
        );
    }

    #[test]
    fn cache_key_includes_cache_name() {
        let cache_a = EmbeddingsCache::new(CacheConfig::new("cache_a", "redis://localhost:6379"));
        let cache_b = EmbeddingsCache::new(CacheConfig::new("cache_b", "redis://localhost:6379"));
        let key_a = cache_a.make_cache_key("hello", "model");
        let key_b = cache_b.make_cache_key("hello", "model");
        assert!(key_a.starts_with("cache_a:"));
        assert!(key_b.starts_with("cache_b:"));
        assert_ne!(key_a, key_b);
    }

    #[test]
    fn distance_threshold_out_of_range() {
        for rejected in [-1.0, 2.5] {
            assert!(validate_distance_threshold(rejected).is_err());
        }
        for accepted in [0.0, 1.0, 2.0] {
            assert!(validate_distance_threshold(accepted).is_ok());
        }
    }

    #[test]
    fn metadata_must_be_object() {
        for rejected in [json!("string"), json!([1, 2]), json!(42)] {
            assert!(validate_metadata(&rejected).is_err());
        }
        for accepted in [json!({"key": "value"}), json!({})] {
            assert!(validate_metadata(&accepted).is_ok());
        }
    }

    #[test]
    fn filterable_fields_reserved_name() {
        let fields = [json!({"name": "metadata", "type": "tag"})];
        let err = validate_filterable_fields(&fields).unwrap_err();
        assert!(err.to_string().contains("reserved"));
    }

    #[test]
    fn filterable_fields_duplicate_name() {
        let fields = [
            json!({"name": "label", "type": "tag"}),
            json!({"name": "label", "type": "tag"}),
        ];
        let err = validate_filterable_fields(&fields).unwrap_err();
        assert!(err.to_string().contains("duplicate"));
    }

    #[test]
    fn filterable_fields_invalid_type() {
        let fields = [
            json!({"name": "label", "type": "tag"}),
            json!({"name": "test", "type": "nothing"}),
        ];
        let err = validate_filterable_fields(&fields).unwrap_err();
        assert!(err.to_string().contains("invalid"));
    }

    #[test]
    fn filterable_fields_valid() {
        let fields = [
            json!({"name": "label", "type": "tag"}),
            json!({"name": "score", "type": "numeric"}),
        ];
        assert!(validate_filterable_fields(&fields).is_ok());
    }

    #[test]
    fn default_embeddings_cache_name() {
        let cache = EmbeddingsCache::default();
        assert_eq!(cache.config.name, "embedcache");
        assert!(cache.config.ttl_seconds.is_none());
    }

    #[test]
    fn custom_embeddings_cache_config() {
        let cache = EmbeddingsCache::new(
            CacheConfig::new("custom_cache", "redis://localhost:6379").with_ttl(60),
        );
        assert_eq!(cache.config.name, "custom_cache");
        assert_eq!(cache.config.ttl_seconds, Some(60));
    }

    #[test]
    fn semantic_cache_schema_respects_dtype() {
        let cases = [
            (VectorDataType::Float32, "float32"),
            (VectorDataType::Float64, "float64"),
            (VectorDataType::Bfloat16, "bfloat16"),
            (VectorDataType::Float16, "float16"),
        ];
        for (dtype, expected) in cases {
            let schema = semantic_cache_schema("test", 128, dtype, &[]);
            let vector_field = schema["fields"]
                .as_array()
                .unwrap()
                .iter()
                .find(|field| field["name"] == "prompt_vector")
                .unwrap();
            assert_eq!(vector_field["attrs"]["datatype"], expected);
        }
    }
}