1use std::{collections::HashMap, sync::Arc};
13
14use chrono::Utc;
15use redis::AsyncCommands;
16use serde::{Deserialize, Serialize};
17use serde_json::{Map, Number, Value, json};
18use sha2::{Digest, Sha256};
19
20use crate::{
21 error::Result,
22 filter::FilterExpression,
23 index::{AsyncSearchIndex, QueryOutput, RedisConnectionInfo, SearchIndex},
24 query::{Vector, VectorRangeQuery},
25 schema::VectorDataType,
26 vectorizers::Vectorizer,
27};
28
/// Hash field holding the deterministic entry id (also the key suffix).
const SEMANTIC_ENTRY_ID_FIELD: &str = "entry_id";
/// Hash field holding the original prompt text.
const SEMANTIC_PROMPT_FIELD: &str = "prompt";
/// Hash field holding the cached response text.
const SEMANTIC_RESPONSE_FIELD: &str = "response";
/// Hash field holding the prompt's embedding vector.
const SEMANTIC_VECTOR_FIELD: &str = "prompt_vector";
/// Hash field holding the creation timestamp (fractional seconds).
const SEMANTIC_INSERTED_AT_FIELD: &str = "inserted_at";
/// Hash field holding the last-update timestamp (fractional seconds).
const SEMANTIC_UPDATED_AT_FIELD: &str = "updated_at";
/// Hash field holding caller-supplied JSON metadata.
const SEMANTIC_METADATA_FIELD: &str = "metadata";
/// Result field from which TTL refresh reads the matched Redis key.
const SEMANTIC_KEY_FIELD: &str = "key";
37
/// Connection and expiry settings shared by both cache types.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Cache name; used as the Redis key prefix (and index name for the
    /// semantic cache).
    pub name: String,
    /// Redis endpoint the cache connects to.
    pub connection: RedisConnectionInfo,
    /// Optional default TTL applied to entries, in seconds.
    pub ttl_seconds: Option<u64>,
}
48
49impl CacheConfig {
50 pub fn new(name: impl Into<String>, redis_url: impl Into<String>) -> Self {
52 Self {
53 name: name.into(),
54 connection: RedisConnectionInfo::new(redis_url),
55 ttl_seconds: None,
56 }
57 }
58
59 #[must_use]
61 pub fn with_ttl(mut self, ttl_seconds: u64) -> Self {
62 self.ttl_seconds = Some(ttl_seconds);
63 self
64 }
65}
66
67impl Default for CacheConfig {
68 fn default() -> Self {
69 Self::new("embedcache", "redis://127.0.0.1:6379")
70 }
71}
72
/// A fully materialized embeddings-cache record as read from Redis.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EmbeddingCacheEntry {
    /// Deterministic id derived from content and model name.
    pub entry_id: String,
    /// The text that was embedded.
    pub content: String,
    /// Name of the model that produced the embedding.
    pub model_name: String,
    /// The embedding vector itself.
    pub embedding: Vec<f32>,
    /// Optional caller-supplied metadata (arbitrary JSON).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Value>,
}
88
/// Input form for batched writes (`mset`/`amset`); the entry id is
/// derived from `content` and `model_name` at write time.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EmbeddingCacheItem {
    /// The text that was embedded.
    pub content: String,
    /// Name of the model that produced the embedding.
    pub model_name: String,
    /// The embedding vector itself.
    pub embedding: Vec<f32>,
    /// Optional caller-supplied metadata (arbitrary JSON).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Value>,
}
102
/// A Redis-backed semantic (similarity) cache for prompt/response pairs.
#[derive(Clone)]
pub struct SemanticCache {
    /// Shared connection/TTL configuration.
    pub config: CacheConfig,
    /// Maximum vector distance for a lookup to count as a hit.
    pub distance_threshold: f32,
    /// Expected dimensionality of prompt vectors.
    pub vector_dimensions: usize,
    /// Storage datatype of the vector field.
    pub dtype: VectorDataType,
    /// Underlying search index holding the cached entries.
    pub index: SearchIndex,
    // Optional embedder used when callers pass a prompt without a vector.
    vectorizer: Option<Arc<dyn Vectorizer>>,
    // Fields requested back from the index on lookups.
    return_fields: Vec<String>,
}
119
120impl SemanticCache {
121 pub fn new(
123 config: CacheConfig,
124 distance_threshold: f32,
125 vector_dimensions: usize,
126 ) -> Result<Self> {
127 Self::with_options(
128 config,
129 distance_threshold,
130 vector_dimensions,
131 VectorDataType::Float32,
132 &[],
133 )
134 }
135
136 pub fn with_dtype(
138 config: CacheConfig,
139 distance_threshold: f32,
140 vector_dimensions: usize,
141 dtype: VectorDataType,
142 ) -> Result<Self> {
143 Self::with_options(config, distance_threshold, vector_dimensions, dtype, &[])
144 }
145
146 pub fn with_filterable_fields(
148 config: CacheConfig,
149 distance_threshold: f32,
150 vector_dimensions: usize,
151 filterable_fields: &[Value],
152 ) -> Result<Self> {
153 Self::with_options(
154 config,
155 distance_threshold,
156 vector_dimensions,
157 VectorDataType::Float32,
158 filterable_fields,
159 )
160 }
161
    /// Create a semantic cache with full control over dtype and extra
    /// filterable fields.
    ///
    /// Validates the threshold, dimension count, and filterable-field
    /// specs, builds the index schema, and creates the index if it does
    /// not already exist.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidInput` for an out-of-range threshold, zero
    /// dimensions, or malformed filterable fields; index construction and
    /// creation errors are propagated.
    pub fn with_options(
        config: CacheConfig,
        distance_threshold: f32,
        vector_dimensions: usize,
        dtype: VectorDataType,
        filterable_fields: &[Value],
    ) -> Result<Self> {
        validate_distance_threshold(distance_threshold)?;
        if vector_dimensions == 0 {
            return Err(crate::Error::InvalidInput(
                "vector_dimensions must be greater than zero".to_owned(),
            ));
        }
        validate_filterable_fields(filterable_fields)?;

        let schema =
            semantic_cache_schema(&config.name, vector_dimensions, dtype, filterable_fields);
        let index = SearchIndex::from_json_value(schema, config.connection.redis_url.clone())?;
        // NOTE(review): a failed `exists()` check is treated as "missing"
        // and triggers a create attempt — presumably intentional
        // best-effort; confirm create errors surface clearly in that case.
        if !index.exists().unwrap_or(false) {
            index.create_with_options(false, false)?;
        }

        Ok(Self {
            config,
            distance_threshold,
            vector_dimensions,
            dtype,
            index,
            vectorizer: None,
            return_fields: default_semantic_return_fields(),
        })
    }
195
    /// Attach a vectorizer used to embed prompts when no explicit vector
    /// is supplied (builder style).
    #[must_use]
    pub fn with_vectorizer<V>(mut self, vectorizer: V) -> Self
    where
        V: Vectorizer + 'static,
    {
        self.vectorizer = Some(Arc::new(vectorizer));
        self
    }

    /// Attach the default HuggingFace vectorizer (requires the `hf-local`
    /// feature). Fails if the vectorizer cannot be constructed.
    #[cfg(feature = "hf-local")]
    pub fn with_default_vectorizer(self) -> Result<Self> {
        let vectorizer = crate::vectorizers::HuggingFaceTextVectorizer::new(Default::default())?;
        Ok(self.with_vectorizer(vectorizer))
    }

    /// Replace the vectorizer in place (non-builder variant of
    /// [`Self::with_vectorizer`]).
    pub fn set_vectorizer<V>(&mut self, vectorizer: V)
    where
        V: Vectorizer + 'static,
    {
        self.vectorizer = Some(Arc::new(vectorizer));
    }
228
229 pub fn ttl(&self) -> Option<u64> {
231 self.config.ttl_seconds
232 }
233
234 pub fn set_ttl(&mut self, ttl_seconds: Option<u64>) {
236 self.config.ttl_seconds = ttl_seconds;
237 }
238
239 pub fn set_threshold(&mut self, distance_threshold: f32) -> Result<()> {
241 validate_distance_threshold(distance_threshold)?;
242 self.distance_threshold = distance_threshold;
243 Ok(())
244 }
245
246 pub fn store(
248 &self,
249 prompt: &str,
250 response: &str,
251 vector: Option<&[f32]>,
252 metadata: Option<Value>,
253 filters: Option<Map<String, Value>>,
254 ttl_seconds: Option<u64>,
255 ) -> Result<String> {
256 if let Some(metadata) = metadata.as_ref() {
257 validate_metadata(metadata)?;
258 }
259
260 let vector = self.resolve_vector(prompt, vector)?;
261 let timestamp = current_timestamp();
262 let entry_id = semantic_entry_id(prompt, filters.as_ref());
263 let mut record = Map::new();
264 record.insert(SEMANTIC_ENTRY_ID_FIELD.to_owned(), Value::String(entry_id));
265 record.insert(
266 SEMANTIC_PROMPT_FIELD.to_owned(),
267 Value::String(prompt.to_owned()),
268 );
269 record.insert(
270 SEMANTIC_RESPONSE_FIELD.to_owned(),
271 Value::String(response.to_owned()),
272 );
273 record.insert(
274 SEMANTIC_VECTOR_FIELD.to_owned(),
275 Value::Array(
276 vector
277 .iter()
278 .copied()
279 .map(|value| number_value(f64::from(value)))
280 .collect(),
281 ),
282 );
283 record.insert(
284 SEMANTIC_INSERTED_AT_FIELD.to_owned(),
285 number_value(timestamp),
286 );
287 record.insert(
288 SEMANTIC_UPDATED_AT_FIELD.to_owned(),
289 number_value(timestamp),
290 );
291 if let Some(metadata) = metadata {
292 record.insert(SEMANTIC_METADATA_FIELD.to_owned(), metadata);
293 }
294 if let Some(filters) = filters {
295 for (key, value) in filters {
296 record.insert(key, value);
297 }
298 }
299
300 let keys = self.index.load(
301 &[Value::Object(record)],
302 SEMANTIC_ENTRY_ID_FIELD,
303 ttl_seconds
304 .or(self.config.ttl_seconds)
305 .map(|value| value as i64),
306 )?;
307 Ok(keys.into_iter().next().unwrap_or_default())
308 }
309
310 pub async fn astore(
312 &self,
313 prompt: &str,
314 response: &str,
315 vector: Option<&[f32]>,
316 metadata: Option<Value>,
317 filters: Option<Map<String, Value>>,
318 ttl_seconds: Option<u64>,
319 ) -> Result<String> {
320 if let Some(metadata) = metadata.as_ref() {
321 validate_metadata(metadata)?;
322 }
323
324 let vector = self.resolve_vector(prompt, vector)?;
325 let timestamp = current_timestamp();
326 let entry_id = semantic_entry_id(prompt, filters.as_ref());
327 let mut record = Map::new();
328 record.insert(SEMANTIC_ENTRY_ID_FIELD.to_owned(), Value::String(entry_id));
329 record.insert(
330 SEMANTIC_PROMPT_FIELD.to_owned(),
331 Value::String(prompt.to_owned()),
332 );
333 record.insert(
334 SEMANTIC_RESPONSE_FIELD.to_owned(),
335 Value::String(response.to_owned()),
336 );
337 record.insert(
338 SEMANTIC_VECTOR_FIELD.to_owned(),
339 Value::Array(
340 vector
341 .iter()
342 .copied()
343 .map(|value| number_value(f64::from(value)))
344 .collect(),
345 ),
346 );
347 record.insert(
348 SEMANTIC_INSERTED_AT_FIELD.to_owned(),
349 number_value(timestamp),
350 );
351 record.insert(
352 SEMANTIC_UPDATED_AT_FIELD.to_owned(),
353 number_value(timestamp),
354 );
355 if let Some(metadata) = metadata {
356 record.insert(SEMANTIC_METADATA_FIELD.to_owned(), metadata);
357 }
358 if let Some(filters) = filters {
359 for (key, value) in filters {
360 record.insert(key, value);
361 }
362 }
363
364 let keys = self
365 .async_index()
366 .load(
367 &[Value::Object(record)],
368 SEMANTIC_ENTRY_ID_FIELD,
369 ttl_seconds
370 .or(self.config.ttl_seconds)
371 .map(|value| value as i64),
372 )
373 .await?;
374 Ok(keys.into_iter().next().unwrap_or_default())
375 }
376
377 pub fn check(
379 &self,
380 prompt: Option<&str>,
381 vector: Option<&[f32]>,
382 num_results: usize,
383 return_fields: Option<&[&str]>,
384 filter_expression: Option<FilterExpression>,
385 distance_threshold: Option<f32>,
386 ) -> Result<Vec<Map<String, Value>>> {
387 let vector = self.resolve_query_vector(prompt, vector)?;
388 let threshold = distance_threshold.unwrap_or(self.distance_threshold);
389 validate_distance_threshold(threshold)?;
390 let mut query = VectorRangeQuery::new(
391 Vector::new(vector.clone()),
392 SEMANTIC_VECTOR_FIELD,
393 threshold,
394 )
395 .paging(0, num_results)
396 .with_return_fields(self.return_fields.iter().map(String::as_str));
397 if let Some(filter_expression) = filter_expression {
398 query = query.with_filter(filter_expression);
399 }
400
401 let hits = process_semantic_hits(
402 query_output_documents(self.index.query(&query)?)?,
403 return_fields,
404 )?;
405 self.refresh_ttl_sync(&hits)?;
406 Ok(hits)
407 }
408
409 pub async fn acheck(
411 &self,
412 prompt: Option<&str>,
413 vector: Option<&[f32]>,
414 num_results: usize,
415 return_fields: Option<&[&str]>,
416 filter_expression: Option<FilterExpression>,
417 distance_threshold: Option<f32>,
418 ) -> Result<Vec<Map<String, Value>>> {
419 let vector = self.resolve_query_vector(prompt, vector)?;
420 let threshold = distance_threshold.unwrap_or(self.distance_threshold);
421 validate_distance_threshold(threshold)?;
422 let mut query = VectorRangeQuery::new(
423 Vector::new(vector.clone()),
424 SEMANTIC_VECTOR_FIELD,
425 threshold,
426 )
427 .paging(0, num_results)
428 .with_return_fields(self.return_fields.iter().map(String::as_str));
429 if let Some(filter_expression) = filter_expression {
430 query = query.with_filter(filter_expression);
431 }
432
433 let hits = process_semantic_hits(
434 query_output_documents(self.async_index().query(&query).await?)?,
435 return_fields,
436 )?;
437 self.refresh_ttl_async(&hits).await?;
438 Ok(hits)
439 }
440
    /// Partially update an existing cache entry's hash fields.
    ///
    /// `fields` are normalized by `prepare_semantic_update_fields` before
    /// being written with a single HSET; the entry's TTL is then
    /// refreshed from the cache-level setting.
    pub fn update(&self, key: &str, fields: Map<String, Value>) -> Result<()> {
        let mapping = prepare_semantic_update_fields(fields)?;
        let client = self.config.connection.client()?;
        let mut connection = client.get_connection()?;
        let mut cmd = redis::cmd("HSET");
        cmd.arg(key);
        for (field, value) in mapping {
            cmd.arg(field).arg(value);
        }
        // HSET returns the number of newly created fields; ignored here.
        let _: usize = cmd.query(&mut connection)?;
        self.expire_key(key, None)
    }

    /// Async variant of [`Self::update`].
    pub async fn aupdate(&self, key: &str, fields: Map<String, Value>) -> Result<()> {
        let mapping = prepare_semantic_update_fields(fields)?;
        let client = self.config.connection.client()?;
        let mut connection = client.get_multiplexed_async_connection().await?;
        let mut cmd = redis::cmd("HSET");
        cmd.arg(key);
        for (field, value) in mapping {
            cmd.arg(field).arg(value);
        }
        let _: usize = cmd.query_async(&mut connection).await?;
        self.aexpire_key(key, None).await
    }
468
469 pub fn clear(&self) -> Result<usize> {
471 self.index.clear()
472 }
473
474 pub async fn aclear(&self) -> Result<usize> {
476 self.async_index().clear().await
477 }
478
479 pub fn delete(&self) -> Result<()> {
481 self.index.delete(true)
482 }
483
484 pub async fn adelete(&self) -> Result<()> {
486 self.async_index().delete(true).await
487 }
488
489 pub fn drop_ids(&self, ids: &[String]) -> Result<()> {
491 let keys = ids.iter().map(|id| self.index.key(id)).collect::<Vec<_>>();
492 self.index.drop_keys(&keys)?;
493 Ok(())
494 }
495
496 pub fn drop_keys(&self, keys: &[String]) -> Result<()> {
498 self.index.drop_keys(keys)?;
499 Ok(())
500 }
501
502 pub async fn adrop_ids(&self, ids: &[String]) -> Result<()> {
504 let keys = ids.iter().map(|id| self.index.key(id)).collect::<Vec<_>>();
505 self.async_index().drop_keys(&keys).await?;
506 Ok(())
507 }
508
509 pub async fn adrop_keys(&self, keys: &[String]) -> Result<()> {
511 self.async_index().drop_keys(keys).await?;
512 Ok(())
513 }
514
515 fn resolve_query_vector(
516 &self,
517 prompt: Option<&str>,
518 vector: Option<&[f32]>,
519 ) -> Result<Vec<f32>> {
520 match (prompt, vector) {
521 (_, Some(vector)) => self.validate_vector(vector),
522 (Some(prompt), None) => self.resolve_vector(prompt, None),
523 (None, None) => Err(crate::Error::InvalidInput(
524 "either prompt or vector must be specified".to_owned(),
525 )),
526 }
527 }
528
529 fn resolve_vector(&self, prompt: &str, vector: Option<&[f32]>) -> Result<Vec<f32>> {
530 match vector {
531 Some(vector) => self.validate_vector(vector),
532 None => {
533 let Some(vectorizer) = &self.vectorizer else {
534 return Err(crate::Error::InvalidInput(
535 "a vector or configured vectorizer is required".to_owned(),
536 ));
537 };
538 let vector = vectorizer.embed(prompt)?;
539 self.validate_vector(&vector)
540 }
541 }
542 }
543
544 fn validate_vector(&self, vector: &[f32]) -> Result<Vec<f32>> {
545 if vector.len() != self.vector_dimensions {
546 return Err(crate::Error::InvalidInput(format!(
547 "vector dimensions mismatch: expected {}, got {}",
548 self.vector_dimensions,
549 vector.len()
550 )));
551 }
552 Ok(vector.to_vec())
553 }
554
    /// Build an async view of the search index sharing schema and URL.
    fn async_index(&self) -> AsyncSearchIndex {
        AsyncSearchIndex::new(
            self.index.schema().clone(),
            self.config.connection.redis_url.clone(),
        )
    }

    /// Re-arm the TTL of every returned hit (read-refresh), sync path.
    /// No-op when the cache has no TTL configured.
    fn refresh_ttl_sync(&self, hits: &[Map<String, Value>]) -> Result<()> {
        if self.config.ttl_seconds.is_none() {
            return Ok(());
        }
        for hit in hits {
            if let Some(key) = hit.get(SEMANTIC_KEY_FIELD).and_then(Value::as_str) {
                self.expire_key(key, None)?;
            }
        }
        Ok(())
    }

    /// Async variant of [`Self::refresh_ttl_sync`].
    async fn refresh_ttl_async(&self, hits: &[Map<String, Value>]) -> Result<()> {
        if self.config.ttl_seconds.is_none() {
            return Ok(());
        }
        for hit in hits {
            if let Some(key) = hit.get(SEMANTIC_KEY_FIELD).and_then(Value::as_str) {
                self.aexpire_key(key, None).await?;
            }
        }
        Ok(())
    }

    /// Apply an EXPIRE to `key`; `ttl_override` wins over the cache-level
    /// TTL. Does nothing when neither TTL is set.
    fn expire_key(&self, key: &str, ttl_override: Option<u64>) -> Result<()> {
        if let Some(ttl_seconds) = ttl_override.or(self.config.ttl_seconds) {
            let client = self.config.connection.client()?;
            let mut connection = client.get_connection()?;
            let _: bool = redis::cmd("EXPIRE")
                .arg(key)
                .arg(ttl_seconds)
                .query(&mut connection)?;
        }
        Ok(())
    }

    /// Async variant of [`Self::expire_key`].
    async fn aexpire_key(&self, key: &str, ttl_override: Option<u64>) -> Result<()> {
        if let Some(ttl_seconds) = ttl_override.or(self.config.ttl_seconds) {
            let client = self.config.connection.client()?;
            let mut connection = client.get_multiplexed_async_connection().await?;
            let _: bool = redis::cmd("EXPIRE")
                .arg(key)
                .arg(ttl_seconds)
                .query_async(&mut connection)
                .await?;
        }
        Ok(())
    }
610}
611
612impl std::fmt::Debug for SemanticCache {
613 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
614 formatter
615 .debug_struct("SemanticCache")
616 .field("config", &self.config)
617 .field("distance_threshold", &self.distance_threshold)
618 .field("vector_dimensions", &self.vector_dimensions)
619 .field("index_name", &self.index.name())
620 .finish()
621 }
622}
623
/// A Redis-backed exact-match cache for precomputed text embeddings.
#[derive(Debug, Clone)]
pub struct EmbeddingsCache {
    /// Shared connection/TTL configuration.
    pub config: CacheConfig,
}
630
631impl Default for EmbeddingsCache {
632 fn default() -> Self {
633 Self::new(CacheConfig::default())
634 }
635}
636
637impl EmbeddingsCache {
638 pub fn new(config: CacheConfig) -> Self {
640 Self { config }
641 }
642
643 pub fn make_entry_id(&self, content: &str, model_name: &str) -> String {
645 hashify(&format!("{content}:{model_name}"))
646 }
647
648 pub fn make_cache_key(&self, content: &str, model_name: &str) -> String {
650 let entry_id = self.make_entry_id(content, model_name);
651 self.key_for_entry(&entry_id)
652 }
653
    /// Fetch a cached embedding by content and model name.
    pub fn get(&self, content: &str, model_name: &str) -> Result<Option<EmbeddingCacheEntry>> {
        let key = self.make_cache_key(content, model_name);
        self.get_by_key(&key)
    }

    /// Fetch a cached embedding by its full Redis key.
    ///
    /// Returns `Ok(None)` when the key does not exist. A hit re-arms the
    /// key's TTL (read-refresh) before the hash is parsed.
    pub fn get_by_key(&self, key: &str) -> Result<Option<EmbeddingCacheEntry>> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_connection()?;
        let data: HashMap<String, String> =
            redis::cmd("HGETALL").arg(key).query(&mut connection)?;

        if data.is_empty() {
            return Ok(None);
        }

        self.expire_key(key, None)?;
        parse_entry(data)
    }
674
675 pub fn mget<I, S>(
677 &self,
678 contents: I,
679 model_name: &str,
680 ) -> Result<Vec<Option<EmbeddingCacheEntry>>>
681 where
682 I: IntoIterator<Item = S>,
683 S: AsRef<str>,
684 {
685 let keys = contents
686 .into_iter()
687 .map(|content| self.make_cache_key(content.as_ref(), model_name))
688 .collect::<Vec<_>>();
689 self.mget_by_keys(keys)
690 }
691
692 pub fn mget_by_keys<I, S>(&self, keys: I) -> Result<Vec<Option<EmbeddingCacheEntry>>>
694 where
695 I: IntoIterator<Item = S>,
696 S: AsRef<str>,
697 {
698 let keys = collect_strings(keys);
699 if keys.is_empty() {
700 return Ok(Vec::new());
701 }
702
703 let mut results = Vec::with_capacity(keys.len());
704 for key in &keys {
705 results.push(self.get_by_key(key)?);
706 }
707 Ok(results)
708 }
709
710 pub fn set(
712 &self,
713 content: &str,
714 model_name: &str,
715 embedding: &[f32],
716 metadata: Option<Value>,
717 ttl_seconds: Option<u64>,
718 ) -> Result<String> {
719 let entry = self.prepare_entry(content, model_name, embedding, metadata);
720 let key = self.key_for_entry(&entry.entry_id);
721 self.write_entry(&key, &entry)?;
722 self.expire_key(&key, ttl_seconds)?;
723 Ok(key)
724 }
725
726 pub fn mset(
728 &self,
729 items: &[EmbeddingCacheItem],
730 ttl_seconds: Option<u64>,
731 ) -> Result<Vec<String>> {
732 let mut keys = Vec::with_capacity(items.len());
733 for item in items {
734 let key = self.set(
735 &item.content,
736 &item.model_name,
737 &item.embedding,
738 item.metadata.clone(),
739 ttl_seconds,
740 )?;
741 keys.push(key);
742 }
743 Ok(keys)
744 }
745
746 pub fn exists(&self, content: &str, model_name: &str) -> Result<bool> {
748 let key = self.make_cache_key(content, model_name);
749 self.exists_by_key(&key)
750 }
751
752 pub fn exists_by_key(&self, key: &str) -> Result<bool> {
754 let client = self.config.connection.client()?;
755 let mut connection = client.get_connection()?;
756 let exists: u64 = redis::cmd("EXISTS").arg(key).query(&mut connection)?;
757 Ok(exists > 0)
758 }
759
760 pub fn mexists<I, S>(&self, contents: I, model_name: &str) -> Result<Vec<bool>>
762 where
763 I: IntoIterator<Item = S>,
764 S: AsRef<str>,
765 {
766 let keys = contents
767 .into_iter()
768 .map(|content| self.make_cache_key(content.as_ref(), model_name))
769 .collect::<Vec<_>>();
770 self.mexists_by_keys(keys)
771 }
772
773 pub fn mexists_by_keys<I, S>(&self, keys: I) -> Result<Vec<bool>>
775 where
776 I: IntoIterator<Item = S>,
777 S: AsRef<str>,
778 {
779 let keys = collect_strings(keys);
780 if keys.is_empty() {
781 return Ok(Vec::new());
782 }
783
784 let client = self.config.connection.client()?;
785 let mut connection = client.get_connection()?;
786 let mut results = Vec::with_capacity(keys.len());
787 for key in keys {
788 let exists: u64 = redis::cmd("EXISTS").arg(key).query(&mut connection)?;
789 results.push(exists > 0);
790 }
791 Ok(results)
792 }
793
794 pub fn drop(&self, content: &str, model_name: &str) -> Result<()> {
796 let key = self.make_cache_key(content, model_name);
797 self.drop_by_key(&key)
798 }
799
800 pub fn drop_by_key(&self, key: &str) -> Result<()> {
802 let client = self.config.connection.client()?;
803 let mut connection = client.get_connection()?;
804 let _: usize = redis::cmd("DEL").arg(key).query(&mut connection)?;
805 Ok(())
806 }
807
808 pub fn mdrop<I, S>(&self, contents: I, model_name: &str) -> Result<()>
810 where
811 I: IntoIterator<Item = S>,
812 S: AsRef<str>,
813 {
814 let keys = contents
815 .into_iter()
816 .map(|content| self.make_cache_key(content.as_ref(), model_name))
817 .collect::<Vec<_>>();
818 self.mdrop_by_keys(keys)
819 }
820
821 pub fn mdrop_by_keys<I, S>(&self, keys: I) -> Result<()>
823 where
824 I: IntoIterator<Item = S>,
825 S: AsRef<str>,
826 {
827 let keys = collect_strings(keys);
828 if keys.is_empty() {
829 return Ok(());
830 }
831
832 let client = self.config.connection.client()?;
833 let mut connection = client.get_connection()?;
834 let _: usize = redis::cmd("DEL").arg(keys).query(&mut connection)?;
835 Ok(())
836 }
837
838 pub fn clear(&self) -> Result<usize> {
840 let keys = self.all_keys()?;
841 if keys.is_empty() {
842 return Ok(0);
843 }
844
845 let count = keys.len();
846 self.mdrop_by_keys(keys)?;
847 Ok(count)
848 }
849
    /// Async variant of [`Self::get`].
    pub async fn aget(
        &self,
        content: &str,
        model_name: &str,
    ) -> Result<Option<EmbeddingCacheEntry>> {
        let key = self.make_cache_key(content, model_name);
        self.aget_by_key(&key).await
    }

    /// Async variant of [`Self::get_by_key`].
    ///
    /// Returns `Ok(None)` for a missing key; a hit re-arms the key's TTL
    /// before the hash is parsed.
    pub async fn aget_by_key(&self, key: &str) -> Result<Option<EmbeddingCacheEntry>> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_multiplexed_async_connection().await?;
        let data: HashMap<String, String> = redis::cmd("HGETALL")
            .arg(key)
            .query_async(&mut connection)
            .await?;

        if data.is_empty() {
            return Ok(None);
        }

        self.aexpire_key(key, None).await?;
        parse_entry(data)
    }
876
877 pub async fn amget<I, S>(
879 &self,
880 contents: I,
881 model_name: &str,
882 ) -> Result<Vec<Option<EmbeddingCacheEntry>>>
883 where
884 I: IntoIterator<Item = S>,
885 S: AsRef<str>,
886 {
887 let keys = contents
888 .into_iter()
889 .map(|content| self.make_cache_key(content.as_ref(), model_name))
890 .collect::<Vec<_>>();
891 self.amget_by_keys(keys).await
892 }
893
894 pub async fn amget_by_keys<I, S>(&self, keys: I) -> Result<Vec<Option<EmbeddingCacheEntry>>>
896 where
897 I: IntoIterator<Item = S>,
898 S: AsRef<str>,
899 {
900 let keys = collect_strings(keys);
901 if keys.is_empty() {
902 return Ok(Vec::new());
903 }
904
905 let mut results = Vec::with_capacity(keys.len());
906 for key in &keys {
907 results.push(self.aget_by_key(key).await?);
908 }
909 Ok(results)
910 }
911
912 pub async fn aset(
914 &self,
915 content: &str,
916 model_name: &str,
917 embedding: &[f32],
918 metadata: Option<Value>,
919 ttl_seconds: Option<u64>,
920 ) -> Result<String> {
921 let entry = self.prepare_entry(content, model_name, embedding, metadata);
922 let key = self.key_for_entry(&entry.entry_id);
923 self.awrite_entry(&key, &entry).await?;
924 self.aexpire_key(&key, ttl_seconds).await?;
925 Ok(key)
926 }
927
928 pub async fn amset(
930 &self,
931 items: &[EmbeddingCacheItem],
932 ttl_seconds: Option<u64>,
933 ) -> Result<Vec<String>> {
934 let mut keys = Vec::with_capacity(items.len());
935 for item in items {
936 let key = self
937 .aset(
938 &item.content,
939 &item.model_name,
940 &item.embedding,
941 item.metadata.clone(),
942 ttl_seconds,
943 )
944 .await?;
945 keys.push(key);
946 }
947 Ok(keys)
948 }
949
950 pub async fn aexists(&self, content: &str, model_name: &str) -> Result<bool> {
952 let key = self.make_cache_key(content, model_name);
953 self.aexists_by_key(&key).await
954 }
955
956 pub async fn aexists_by_key(&self, key: &str) -> Result<bool> {
958 let client = self.config.connection.client()?;
959 let mut connection = client.get_multiplexed_async_connection().await?;
960 Ok(connection.exists(key).await?)
961 }
962
963 pub async fn amexists<I, S>(&self, contents: I, model_name: &str) -> Result<Vec<bool>>
965 where
966 I: IntoIterator<Item = S>,
967 S: AsRef<str>,
968 {
969 let keys = contents
970 .into_iter()
971 .map(|content| self.make_cache_key(content.as_ref(), model_name))
972 .collect::<Vec<_>>();
973 self.amexists_by_keys(keys).await
974 }
975
976 pub async fn amexists_by_keys<I, S>(&self, keys: I) -> Result<Vec<bool>>
978 where
979 I: IntoIterator<Item = S>,
980 S: AsRef<str>,
981 {
982 let keys = collect_strings(keys);
983 if keys.is_empty() {
984 return Ok(Vec::new());
985 }
986
987 let client = self.config.connection.client()?;
988 let mut connection = client.get_multiplexed_async_connection().await?;
989 let mut results = Vec::with_capacity(keys.len());
990 for key in keys {
991 results.push(connection.exists(key).await?);
992 }
993 Ok(results)
994 }
995
996 pub async fn adrop(&self, content: &str, model_name: &str) -> Result<()> {
998 let key = self.make_cache_key(content, model_name);
999 self.adrop_by_key(&key).await
1000 }
1001
1002 pub async fn adrop_by_key(&self, key: &str) -> Result<()> {
1004 let client = self.config.connection.client()?;
1005 let mut connection = client.get_multiplexed_async_connection().await?;
1006 let _: usize = connection.del(key).await?;
1007 Ok(())
1008 }
1009
1010 pub async fn amdrop<I, S>(&self, contents: I, model_name: &str) -> Result<()>
1012 where
1013 I: IntoIterator<Item = S>,
1014 S: AsRef<str>,
1015 {
1016 let keys = contents
1017 .into_iter()
1018 .map(|content| self.make_cache_key(content.as_ref(), model_name))
1019 .collect::<Vec<_>>();
1020 self.amdrop_by_keys(keys).await
1021 }
1022
1023 pub async fn amdrop_by_keys<I, S>(&self, keys: I) -> Result<()>
1025 where
1026 I: IntoIterator<Item = S>,
1027 S: AsRef<str>,
1028 {
1029 let keys = collect_strings(keys);
1030 if keys.is_empty() {
1031 return Ok(());
1032 }
1033
1034 let client = self.config.connection.client()?;
1035 let mut connection = client.get_multiplexed_async_connection().await?;
1036 let _: usize = connection.del(keys).await?;
1037 Ok(())
1038 }
1039
1040 pub async fn aclear(&self) -> Result<usize> {
1042 let keys = self.aall_keys().await?;
1043 if keys.is_empty() {
1044 return Ok(0);
1045 }
1046
1047 let count = keys.len();
1048 self.amdrop_by_keys(keys).await?;
1049 Ok(count)
1050 }
1051
    /// Assemble an [`EmbeddingCacheEntry`] from its parts, deriving the
    /// deterministic entry id.
    fn prepare_entry(
        &self,
        content: &str,
        model_name: &str,
        embedding: &[f32],
        metadata: Option<Value>,
    ) -> EmbeddingCacheEntry {
        EmbeddingCacheEntry {
            entry_id: self.make_entry_id(content, model_name),
            content: content.to_owned(),
            model_name: model_name.to_owned(),
            embedding: embedding.to_vec(),
            metadata,
        }
    }

    /// Persist an entry as a Redis hash under `key`.
    ///
    /// The embedding (and metadata, when present) is stored JSON-encoded
    /// in a single hash field; `parse_entry` reverses this layout.
    fn write_entry(&self, key: &str, entry: &EmbeddingCacheEntry) -> Result<()> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_connection()?;
        let mut cmd = redis::cmd("HSET");
        cmd.arg(key)
            .arg("entry_id")
            .arg(&entry.entry_id)
            .arg("content")
            .arg(&entry.content)
            .arg("model_name")
            .arg(&entry.model_name)
            .arg("embedding")
            .arg(serde_json::to_string(&entry.embedding)?);

        if let Some(metadata) = &entry.metadata {
            cmd.arg("metadata").arg(serde_json::to_string(metadata)?);
        }

        // HSET returns the number of newly created fields; ignored here.
        let _: usize = cmd.query(&mut connection)?;
        Ok(())
    }

    /// Async variant of [`Self::write_entry`].
    async fn awrite_entry(&self, key: &str, entry: &EmbeddingCacheEntry) -> Result<()> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_multiplexed_async_connection().await?;
        let mut cmd = redis::cmd("HSET");
        cmd.arg(key)
            .arg("entry_id")
            .arg(&entry.entry_id)
            .arg("content")
            .arg(&entry.content)
            .arg("model_name")
            .arg(&entry.model_name)
            .arg("embedding")
            .arg(serde_json::to_string(&entry.embedding)?);

        if let Some(metadata) = &entry.metadata {
            cmd.arg("metadata").arg(serde_json::to_string(metadata)?);
        }

        let _: usize = cmd.query_async(&mut connection).await?;
        Ok(())
    }
1111
    /// Apply an EXPIRE to `key`; `ttl_override` wins over the cache-level
    /// TTL. Does nothing when neither TTL is set.
    fn expire_key(&self, key: &str, ttl_override: Option<u64>) -> Result<()> {
        if let Some(ttl_seconds) = ttl_override.or(self.config.ttl_seconds) {
            let client = self.config.connection.client()?;
            let mut connection = client.get_connection()?;
            let _: bool = redis::cmd("EXPIRE")
                .arg(key)
                .arg(ttl_seconds)
                .query(&mut connection)?;
        }
        Ok(())
    }

    /// Async variant of [`Self::expire_key`].
    async fn aexpire_key(&self, key: &str, ttl_override: Option<u64>) -> Result<()> {
        if let Some(ttl_seconds) = ttl_override.or(self.config.ttl_seconds) {
            let client = self.config.connection.client()?;
            let mut connection = client.get_multiplexed_async_connection().await?;
            let _: bool = redis::cmd("EXPIRE")
                .arg(key)
                .arg(ttl_seconds)
                .query_async(&mut connection)
                .await?;
        }
        Ok(())
    }
1136
    /// List every key under this cache's prefix using KEYS.
    ///
    /// NOTE(review): KEYS scans the entire keyspace and blocks Redis; on
    /// large production datasets SCAN would be preferable — confirm this
    /// trade-off is acceptable for `clear()`'s use case.
    fn all_keys(&self) -> Result<Vec<String>> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_connection()?;
        let keys: Vec<String> = redis::cmd("KEYS")
            .arg(format!("{}:*", self.config.name))
            .query(&mut connection)?;
        Ok(keys)
    }

    /// Async variant of [`Self::all_keys`] (same KEYS caveat).
    async fn aall_keys(&self) -> Result<Vec<String>> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_multiplexed_async_connection().await?;
        let keys: Vec<String> = redis::cmd("KEYS")
            .arg(format!("{}:*", self.config.name))
            .query_async(&mut connection)
            .await?;
        Ok(keys)
    }

    /// Compose the full Redis key: `"{cache_name}:{entry_id}"`.
    fn key_for_entry(&self, entry_id: &str) -> String {
        format!("{}:{entry_id}", self.config.name)
    }
1159}
1160
/// Materialize any iterator of string-likes into owned `String`s.
fn collect_strings<I, S>(values: I) -> Vec<String>
where
    I: IntoIterator<Item = S>,
    S: AsRef<str>,
{
    let mut owned = Vec::new();
    for value in values {
        owned.push(value.as_ref().to_owned());
    }
    owned
}
1171
/// Decode an HGETALL result into an [`EmbeddingCacheEntry`].
///
/// Missing `entry_id`/`content`/`model_name` fields default to empty
/// strings and a missing `embedding` to an empty vector — absent fields
/// are tolerated rather than treated as corruption. Malformed JSON in
/// `embedding`/`metadata` is propagated as an error.
fn parse_entry(data: HashMap<String, String>) -> Result<Option<EmbeddingCacheEntry>> {
    if data.is_empty() {
        return Ok(None);
    }

    let entry = EmbeddingCacheEntry {
        entry_id: data.get("entry_id").cloned().unwrap_or_default(),
        content: data.get("content").cloned().unwrap_or_default(),
        model_name: data.get("model_name").cloned().unwrap_or_default(),
        embedding: match data.get("embedding") {
            Some(value) => serde_json::from_str::<Vec<f32>>(value)?,
            None => Vec::new(),
        },
        metadata: data
            .get("metadata")
            .map(|value| serde_json::from_str::<Value>(value))
            .transpose()?,
    };

    Ok(Some(entry))
}
1193
1194fn hashify(content: &str) -> String {
1195 let mut hasher = Sha256::new();
1196 hasher.update(content.as_bytes());
1197 let digest = hasher.finalize();
1198 let mut output = String::with_capacity(digest.len() * 2);
1199 for byte in digest {
1200 use std::fmt::Write as _;
1201 let _ = write!(&mut output, "{byte:02x}");
1202 }
1203 output
1204}
1205
/// Build the JSON index schema for a semantic cache named `name`.
///
/// Entries are stored as hashes under prefix `name`, with a flat
/// cosine-distance vector field plus the caller's extra filterable
/// fields appended verbatim.
fn semantic_cache_schema(
    name: &str,
    vector_dimensions: usize,
    dtype: VectorDataType,
    filterable_fields: &[Value],
) -> Value {
    let mut fields = vec![
        json!({ "name": SEMANTIC_ENTRY_ID_FIELD, "type": "tag" }),
        json!({ "name": SEMANTIC_PROMPT_FIELD, "type": "text" }),
        json!({ "name": SEMANTIC_RESPONSE_FIELD, "type": "text" }),
        json!({ "name": SEMANTIC_INSERTED_AT_FIELD, "type": "numeric" }),
        json!({ "name": SEMANTIC_UPDATED_AT_FIELD, "type": "numeric" }),
        json!({ "name": SEMANTIC_METADATA_FIELD, "type": "text" }),
        json!({
            "name": SEMANTIC_VECTOR_FIELD,
            "type": "vector",
            "attrs": {
                "algorithm": "flat",
                "dims": vector_dimensions,
                "datatype": dtype.as_str(),
                "distance_metric": "cosine"
            }
        }),
    ];
    // User-declared filterable fields are indexed alongside the built-ins.
    fields.extend(filterable_fields.iter().cloned());
    json!({
        "index": {
            "name": name,
            "prefix": name,
            "storage_type": "hash",
        },
        "fields": fields,
    })
}
1240
1241fn default_semantic_return_fields() -> Vec<String> {
1242 vec![
1243 SEMANTIC_ENTRY_ID_FIELD.to_owned(),
1244 SEMANTIC_PROMPT_FIELD.to_owned(),
1245 SEMANTIC_RESPONSE_FIELD.to_owned(),
1246 "vector_distance".to_owned(),
1247 SEMANTIC_INSERTED_AT_FIELD.to_owned(),
1248 SEMANTIC_UPDATED_AT_FIELD.to_owned(),
1249 SEMANTIC_METADATA_FIELD.to_owned(),
1250 ]
1251}
1252
1253fn current_timestamp() -> f64 {
1254 Utc::now().timestamp_millis() as f64 / 1000.0
1255}
1256
1257fn semantic_entry_id(prompt: &str, filters: Option<&Map<String, Value>>) -> String {
1258 if let Some(filters) = filters {
1259 let mut parts = filters
1260 .iter()
1261 .map(|(key, value)| format!("{key}{}", value_to_hash_string(value)))
1262 .collect::<Vec<_>>();
1263 parts.sort();
1264 hashify(&format!("{prompt}{}", parts.join("")))
1265 } else {
1266 hashify(prompt)
1267 }
1268}
1269
1270fn value_to_hash_string(value: &Value) -> String {
1271 match value {
1272 Value::Null => "null".to_owned(),
1273 Value::Bool(value) => value.to_string(),
1274 Value::Number(value) => value.to_string(),
1275 Value::String(value) => value.clone(),
1276 Value::Array(_) | Value::Object(_) => serde_json::to_string(value).unwrap_or_default(),
1277 }
1278}
1279
/// Field names owned by the semantic cache schema itself; user-supplied
/// filterable fields must not collide with any of these (enforced by
/// `validate_filterable_fields`).
const RESERVED_SEMANTIC_FIELDS: &[&str] = &[
    SEMANTIC_ENTRY_ID_FIELD,
    SEMANTIC_PROMPT_FIELD,
    SEMANTIC_RESPONSE_FIELD,
    SEMANTIC_VECTOR_FIELD,
    SEMANTIC_INSERTED_AT_FIELD,
    SEMANTIC_UPDATED_AT_FIELD,
    SEMANTIC_METADATA_FIELD,
    SEMANTIC_KEY_FIELD,
    // Pseudo-field produced by vector queries rather than stored in the hash.
    "vector_distance",
];
1292
1293fn validate_filterable_fields(fields: &[Value]) -> Result<()> {
1294 let mut seen = std::collections::HashSet::new();
1295 for field in fields {
1296 let name = field
1297 .get("name")
1298 .and_then(Value::as_str)
1299 .unwrap_or_default();
1300 let field_type = field
1301 .get("type")
1302 .and_then(Value::as_str)
1303 .unwrap_or_default();
1304
1305 if name.is_empty() {
1306 return Err(crate::Error::InvalidInput(
1307 "filterable field must have a non-empty 'name'".to_owned(),
1308 ));
1309 }
1310
1311 if RESERVED_SEMANTIC_FIELDS.contains(&name) {
1312 return Err(crate::Error::InvalidInput(format!(
1313 "{name} is a reserved field name for the semantic cache schema"
1314 )));
1315 }
1316
1317 if !seen.insert(name.to_owned()) {
1318 return Err(crate::Error::InvalidInput(format!(
1319 "duplicate field name: {name}. Field names must be unique"
1320 )));
1321 }
1322
1323 if !matches!(field_type, "tag" | "text" | "numeric" | "geo") {
1324 return Err(crate::Error::InvalidInput(format!(
1325 "invalid filterable field type: '{field_type}' for field '{name}'"
1326 )));
1327 }
1328 }
1329 Ok(())
1330}
1331
1332fn validate_distance_threshold(distance_threshold: f32) -> Result<()> {
1333 if !(0.0..=2.0).contains(&distance_threshold) {
1334 return Err(crate::Error::InvalidInput(format!(
1335 "distance threshold must be between 0 and 2, got {distance_threshold}"
1336 )));
1337 }
1338 Ok(())
1339}
1340
1341fn validate_metadata(metadata: &Value) -> Result<()> {
1342 if !metadata.is_object() {
1343 return Err(crate::Error::InvalidInput(
1344 "metadata must be a JSON object".to_owned(),
1345 ));
1346 }
1347 Ok(())
1348}
1349
1350fn query_output_documents(output: QueryOutput) -> Result<Vec<Map<String, Value>>> {
1351 match output {
1352 QueryOutput::Documents(documents) => Ok(documents),
1353 QueryOutput::Count(_) => Err(crate::Error::InvalidInput(
1354 "semantic cache queries must return documents".to_owned(),
1355 )),
1356 }
1357}
1358
1359fn process_semantic_hits(
1360 documents: Vec<Map<String, Value>>,
1361 return_fields: Option<&[&str]>,
1362) -> Result<Vec<Map<String, Value>>> {
1363 let selected = return_fields.map(|fields| {
1364 fields
1365 .iter()
1366 .map(|field| (*field).to_owned())
1367 .collect::<std::collections::HashSet<_>>()
1368 });
1369 let mut hits = Vec::with_capacity(documents.len());
1370 for mut document in documents {
1371 let key = document
1372 .remove("id")
1373 .unwrap_or_else(|| Value::String(String::new()));
1374 let mut hit = Map::new();
1375 hit.insert(SEMANTIC_KEY_FIELD.to_owned(), key);
1376 for (field, value) in document {
1377 let include = selected
1378 .as_ref()
1379 .is_none_or(|fields| fields.contains(&field));
1380 if !include {
1381 continue;
1382 }
1383 hit.insert(field.clone(), normalize_semantic_value(&field, value)?);
1384 }
1385 hits.push(hit);
1386 }
1387 Ok(hits)
1388}
1389
1390fn normalize_semantic_value(field: &str, value: Value) -> Result<Value> {
1391 match (field, value) {
1392 (SEMANTIC_METADATA_FIELD, Value::String(value)) => {
1393 Ok(serde_json::from_str(&value).unwrap_or(Value::String(value)))
1394 }
1395 (
1396 "vector_distance" | SEMANTIC_INSERTED_AT_FIELD | SEMANTIC_UPDATED_AT_FIELD,
1397 Value::String(value),
1398 ) => {
1399 let parsed = value.parse::<f64>().map_err(|_| {
1400 crate::Error::InvalidInput(format!("could not parse numeric field '{field}'"))
1401 })?;
1402 Ok(number_value(parsed))
1403 }
1404 (_, value) => Ok(value),
1405 }
1406}
1407
1408fn prepare_semantic_update_fields(fields: Map<String, Value>) -> Result<Vec<(String, String)>> {
1409 let mut mapping = Vec::with_capacity(fields.len() + 1);
1410 for (field, value) in fields {
1411 if field == SEMANTIC_VECTOR_FIELD {
1412 return Err(crate::Error::InvalidInput(
1413 "updating the stored vector is not supported yet".to_owned(),
1414 ));
1415 }
1416 if field == SEMANTIC_METADATA_FIELD {
1417 validate_metadata(&value)?;
1418 }
1419 let serialized = match value {
1420 Value::Null => "null".to_owned(),
1421 Value::Bool(value) => value.to_string(),
1422 Value::Number(value) => value.to_string(),
1423 Value::String(value) => value,
1424 Value::Array(_) | Value::Object(_) => serde_json::to_string(&value)?,
1425 };
1426 mapping.push((field, serialized));
1427 }
1428 mapping.push((
1429 SEMANTIC_UPDATED_AT_FIELD.to_owned(),
1430 current_timestamp().to_string(),
1431 ));
1432 Ok(mapping)
1433}
1434
1435fn number_value(value: f64) -> Value {
1436 Number::from_f64(value)
1437 .map(Value::Number)
1438 .unwrap_or(Value::Null)
1439}
1440
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::{
        CacheConfig, EmbeddingsCache, hashify, validate_distance_threshold,
        validate_filterable_fields, validate_metadata,
    };

    // Known-answer test pinning the SHA-256 hex encoding.
    #[test]
    fn hashify_matches_expected_sha256() {
        assert_eq!(
            hashify("Hello world:text-embedding-ada-002"),
            "368dacc611e96e4189a9809faaca1a70b3c3306352bbcfc9ab6291359a5dfca0"
        );
    }

    // Cache keys are "<name>:<sha256(content:model)>" and must not drift.
    #[test]
    fn cache_key_is_stable() {
        let cache = EmbeddingsCache::new(CacheConfig::default());
        let key = cache.make_cache_key("Hello world", "text-embedding-ada-002");
        assert_eq!(
            key,
            "embedcache:368dacc611e96e4189a9809faaca1a70b3c3306352bbcfc9ab6291359a5dfca0"
        );
    }

    // Same (content, model) pair always yields the same entry id.
    #[test]
    fn entry_id_is_deterministic() {
        let cache = EmbeddingsCache::new(CacheConfig::default());
        let id1 = cache.make_entry_id("Hello world", "text-embedding-ada-002");
        let id2 = cache.make_entry_id("Hello world", "text-embedding-ada-002");
        assert_eq!(id1, id2);

        let different = cache.make_entry_id("Different text", "text-embedding-ada-002");
        assert_ne!(id1, different);
    }

    // Distinct content must produce distinct ids for the same model.
    #[test]
    fn entry_id_different_inputs_differ() {
        let cache = EmbeddingsCache::new(CacheConfig::default());
        let id_a = cache.make_entry_id("What is machine learning?", "text-embedding-ada-002");
        let id_b = cache.make_entry_id("How do neural networks work?", "text-embedding-ada-002");
        assert_ne!(id_a, id_b);
    }

    // The cache name is used as the key prefix, isolating caches that
    // share a Redis instance.
    #[test]
    fn cache_key_includes_cache_name() {
        let cache_a = EmbeddingsCache::new(CacheConfig::new("cache_a", "redis://localhost:6379"));
        let cache_b = EmbeddingsCache::new(CacheConfig::new("cache_b", "redis://localhost:6379"));
        let key_a = cache_a.make_cache_key("hello", "model");
        let key_b = cache_b.make_cache_key("hello", "model");
        assert!(key_a.starts_with("cache_a:"));
        assert!(key_b.starts_with("cache_b:"));
        assert_ne!(key_a, key_b);
    }

    // Valid cosine-distance thresholds are exactly [0.0, 2.0].
    #[test]
    fn distance_threshold_out_of_range() {
        assert!(validate_distance_threshold(-1.0).is_err());
        assert!(validate_distance_threshold(2.5).is_err());
        assert!(validate_distance_threshold(0.0).is_ok());
        assert!(validate_distance_threshold(1.0).is_ok());
        assert!(validate_distance_threshold(2.0).is_ok());
    }

    // Only JSON objects (including empty ones) are accepted as metadata.
    #[test]
    fn metadata_must_be_object() {
        assert!(validate_metadata(&json!("string")).is_err());
        assert!(validate_metadata(&json!([1, 2])).is_err());
        assert!(validate_metadata(&json!(42)).is_err());
        assert!(validate_metadata(&json!({"key": "value"})).is_ok());
        assert!(validate_metadata(&json!({})).is_ok());
    }

    // Reserved schema names are rejected for user filterable fields.
    #[test]
    fn filterable_fields_reserved_name() {
        let fields = vec![json!({"name": "metadata", "type": "tag"})];
        let err = validate_filterable_fields(&fields).unwrap_err();
        assert!(err.to_string().contains("reserved"));
    }

    // Duplicate field names within one list are rejected.
    #[test]
    fn filterable_fields_duplicate_name() {
        let fields = vec![
            json!({"name": "label", "type": "tag"}),
            json!({"name": "label", "type": "tag"}),
        ];
        let err = validate_filterable_fields(&fields).unwrap_err();
        assert!(err.to_string().contains("duplicate"));
    }

    // Only tag/text/numeric/geo are valid filterable field types.
    #[test]
    fn filterable_fields_invalid_type() {
        let fields = vec![
            json!({"name": "label", "type": "tag"}),
            json!({"name": "test", "type": "nothing"}),
        ];
        let err = validate_filterable_fields(&fields).unwrap_err();
        assert!(err.to_string().contains("invalid"));
    }

    // A well-formed field list passes validation.
    #[test]
    fn filterable_fields_valid() {
        let fields = vec![
            json!({"name": "label", "type": "tag"}),
            json!({"name": "score", "type": "numeric"}),
        ];
        assert!(validate_filterable_fields(&fields).is_ok());
    }

    // Default config: name "embedcache", no TTL.
    #[test]
    fn default_embeddings_cache_name() {
        let cache = EmbeddingsCache::default();
        assert_eq!(cache.config.name, "embedcache");
        assert!(cache.config.ttl_seconds.is_none());
    }

    // Builder-style config overrides are preserved by the cache.
    #[test]
    fn custom_embeddings_cache_config() {
        let config = CacheConfig::new("custom_cache", "redis://localhost:6379").with_ttl(60);
        let cache = EmbeddingsCache::new(config);
        assert_eq!(cache.config.name, "custom_cache");
        assert_eq!(cache.config.ttl_seconds, Some(60));
    }

    // The vector field's datatype attribute must follow the requested dtype.
    #[test]
    fn semantic_cache_schema_respects_dtype() {
        use super::{VectorDataType, semantic_cache_schema};

        let schema_f32 = semantic_cache_schema("test", 128, VectorDataType::Float32, &[]);
        let vec_field = schema_f32["fields"]
            .as_array()
            .unwrap()
            .iter()
            .find(|f| f["name"] == "prompt_vector")
            .unwrap();
        assert_eq!(vec_field["attrs"]["datatype"], "float32");

        let schema_f64 = semantic_cache_schema("test", 128, VectorDataType::Float64, &[]);
        let vec_field = schema_f64["fields"]
            .as_array()
            .unwrap()
            .iter()
            .find(|f| f["name"] == "prompt_vector")
            .unwrap();
        assert_eq!(vec_field["attrs"]["datatype"], "float64");

        let schema_bfloat16 = semantic_cache_schema("test", 128, VectorDataType::Bfloat16, &[]);
        let vec_field = schema_bfloat16["fields"]
            .as_array()
            .unwrap()
            .iter()
            .find(|f| f["name"] == "prompt_vector")
            .unwrap();
        assert_eq!(vec_field["attrs"]["datatype"], "bfloat16");

        let schema_float16 = semantic_cache_schema("test", 128, VectorDataType::Float16, &[]);
        let vec_field = schema_float16["fields"]
            .as_array()
            .unwrap()
            .iter()
            .find(|f| f["name"] == "prompt_vector")
            .unwrap();
        assert_eq!(vec_field["attrs"]["datatype"], "float16");
    }
}