use std::{collections::HashMap, sync::Arc};
use chrono::Utc;
use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Number, Value, json};
use sha2::{Digest, Sha256};
use crate::{
error::Result,
filter::FilterExpression,
index::{AsyncSearchIndex, QueryOutput, RedisConnectionInfo, SearchIndex},
query::{Vector, VectorRangeQuery},
schema::VectorDataType,
vectorizers::Vectorizer,
};
// Hash-field names shared by semantic-cache records and the index schema.
const SEMANTIC_ENTRY_ID_FIELD: &str = "entry_id";
const SEMANTIC_PROMPT_FIELD: &str = "prompt";
const SEMANTIC_RESPONSE_FIELD: &str = "response";
// Field holding the prompt embedding used for vector range queries.
const SEMANTIC_VECTOR_FIELD: &str = "prompt_vector";
// Unix timestamps (fractional seconds) written on insert/update.
const SEMANTIC_INSERTED_AT_FIELD: &str = "inserted_at";
const SEMANTIC_UPDATED_AT_FIELD: &str = "updated_at";
const SEMANTIC_METADATA_FIELD: &str = "metadata";
// Synthetic field added to query hits carrying the entry's Redis key.
const SEMANTIC_KEY_FIELD: &str = "key";
/// Connection and expiry settings shared by the cache types in this module.
#[derive(Debug, Clone)]
pub struct CacheConfig {
/// Cache name; also used as the Redis key prefix (and index name for the
/// semantic cache).
pub name: String,
/// Redis connection details.
pub connection: RedisConnectionInfo,
/// Default TTL applied to entries, in seconds; `None` disables expiry.
pub ttl_seconds: Option<u64>,
}
impl CacheConfig {
pub fn new(name: impl Into<String>, redis_url: impl Into<String>) -> Self {
Self {
name: name.into(),
connection: RedisConnectionInfo::new(redis_url),
ttl_seconds: None,
}
}
#[must_use]
pub fn with_ttl(mut self, ttl_seconds: u64) -> Self {
self.ttl_seconds = Some(ttl_seconds);
self
}
}
impl Default for CacheConfig {
fn default() -> Self {
Self::new("embedcache", "redis://127.0.0.1:6379")
}
}
/// A cached embedding as read back from Redis, including its derived id.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EmbeddingCacheEntry {
/// SHA-256 of `"{content}:{model_name}"`; also the key suffix.
pub entry_id: String,
/// The text that was embedded.
pub content: String,
/// Model that produced the embedding.
pub model_name: String,
/// The embedding vector itself.
pub embedding: Vec<f32>,
/// Optional caller-supplied metadata (arbitrary JSON).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub metadata: Option<Value>,
}
/// Input item for batch writes (`mset`/`amset`); the entry id is derived
/// from content and model name at store time.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EmbeddingCacheItem {
/// The text that was embedded.
pub content: String,
/// Model that produced the embedding.
pub model_name: String,
/// The embedding vector to store.
pub embedding: Vec<f32>,
/// Optional caller-supplied metadata (arbitrary JSON).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub metadata: Option<Value>,
}
/// Prompt/response cache backed by a Redis vector-search index: lookups
/// return responses whose prompt embedding lies within a distance threshold.
#[derive(Clone)]
pub struct SemanticCache {
/// Name, connection, and default TTL settings.
pub config: CacheConfig,
/// Maximum vector distance for a query to count as a hit (validated 0..=2).
pub distance_threshold: f32,
/// Dimensionality every stored/query vector must match.
pub vector_dimensions: usize,
/// Datatype used by the index for stored vectors.
pub dtype: VectorDataType,
/// Underlying synchronous search index handle.
pub index: SearchIndex,
// Optional embedder used when callers pass a prompt without a vector.
vectorizer: Option<Arc<dyn Vectorizer>>,
// Field names requested from the index on every query.
return_fields: Vec<String>,
}
impl SemanticCache {
    /// Creates a cache with `f32` vectors and no extra filterable fields.
    pub fn new(
        config: CacheConfig,
        distance_threshold: f32,
        vector_dimensions: usize,
    ) -> Result<Self> {
        Self::with_options(
            config,
            distance_threshold,
            vector_dimensions,
            VectorDataType::Float32,
            &[],
        )
    }

    /// Creates a cache with an explicit vector datatype.
    pub fn with_dtype(
        config: CacheConfig,
        distance_threshold: f32,
        vector_dimensions: usize,
        dtype: VectorDataType,
    ) -> Result<Self> {
        Self::with_options(config, distance_threshold, vector_dimensions, dtype, &[])
    }

    /// Creates a cache with additional caller-defined filterable fields.
    pub fn with_filterable_fields(
        config: CacheConfig,
        distance_threshold: f32,
        vector_dimensions: usize,
        filterable_fields: &[Value],
    ) -> Result<Self> {
        Self::with_options(
            config,
            distance_threshold,
            vector_dimensions,
            VectorDataType::Float32,
            filterable_fields,
        )
    }

    /// Fully-parameterized constructor.
    ///
    /// Validates the threshold, dimensionality, and filterable fields, builds
    /// the index schema, and creates the search index if it does not exist.
    ///
    /// # Errors
    /// `InvalidInput` on bad parameters; index errors are propagated.
    pub fn with_options(
        config: CacheConfig,
        distance_threshold: f32,
        vector_dimensions: usize,
        dtype: VectorDataType,
        filterable_fields: &[Value],
    ) -> Result<Self> {
        validate_distance_threshold(distance_threshold)?;
        if vector_dimensions == 0 {
            return Err(crate::Error::InvalidInput(
                "vector_dimensions must be greater than zero".to_owned(),
            ));
        }
        validate_filterable_fields(filterable_fields)?;
        let schema =
            semantic_cache_schema(&config.name, vector_dimensions, dtype, filterable_fields);
        let index = SearchIndex::from_json_value(schema, config.connection.redis_url.clone())?;
        // If the existence check itself errors we assume the index is missing
        // and attempt to create it.
        if !index.exists().unwrap_or(false) {
            index.create_with_options(false, false)?;
        }
        Ok(Self {
            config,
            distance_threshold,
            vector_dimensions,
            dtype,
            index,
            vectorizer: None,
            return_fields: default_semantic_return_fields(),
        })
    }

    /// Returns the cache with `vectorizer` installed for embedding prompts.
    #[must_use]
    pub fn with_vectorizer<V>(mut self, vectorizer: V) -> Self
    where
        V: Vectorizer + 'static,
    {
        self.vectorizer = Some(Arc::new(vectorizer));
        self
    }

    /// Installs the default local HuggingFace vectorizer (feature-gated).
    #[cfg(feature = "hf-local")]
    pub fn with_default_vectorizer(self) -> Result<Self> {
        let vectorizer = crate::vectorizers::HuggingFaceTextVectorizer::new(Default::default())?;
        Ok(self.with_vectorizer(vectorizer))
    }

    /// Replaces the vectorizer in place.
    pub fn set_vectorizer<V>(&mut self, vectorizer: V)
    where
        V: Vectorizer + 'static,
    {
        self.vectorizer = Some(Arc::new(vectorizer));
    }

    /// Default entry TTL in seconds, if configured.
    pub fn ttl(&self) -> Option<u64> {
        self.config.ttl_seconds
    }

    /// Sets (or clears) the default entry TTL.
    pub fn set_ttl(&mut self, ttl_seconds: Option<u64>) {
        self.config.ttl_seconds = ttl_seconds;
    }

    /// Updates the distance threshold after validating it.
    pub fn set_threshold(&mut self, distance_threshold: f32) -> Result<()> {
        validate_distance_threshold(distance_threshold)?;
        self.distance_threshold = distance_threshold;
        Ok(())
    }

    /// Stores a prompt/response pair synchronously and returns its Redis key.
    ///
    /// When `vector` is `None` the configured vectorizer embeds `prompt`.
    /// `filters` are written as extra hash fields; `ttl_seconds` overrides
    /// the configured default TTL for this entry.
    pub fn store(
        &self,
        prompt: &str,
        response: &str,
        vector: Option<&[f32]>,
        metadata: Option<Value>,
        filters: Option<Map<String, Value>>,
        ttl_seconds: Option<u64>,
    ) -> Result<String> {
        let record = self.build_store_record(prompt, response, vector, metadata, filters)?;
        let keys = self.index.load(
            &[Value::Object(record)],
            SEMANTIC_ENTRY_ID_FIELD,
            self.effective_ttl(ttl_seconds),
        )?;
        Ok(keys.into_iter().next().unwrap_or_default())
    }

    /// Async variant of [`SemanticCache::store`].
    pub async fn astore(
        &self,
        prompt: &str,
        response: &str,
        vector: Option<&[f32]>,
        metadata: Option<Value>,
        filters: Option<Map<String, Value>>,
        ttl_seconds: Option<u64>,
    ) -> Result<String> {
        let record = self.build_store_record(prompt, response, vector, metadata, filters)?;
        let keys = self
            .async_index()
            .load(
                &[Value::Object(record)],
                SEMANTIC_ENTRY_ID_FIELD,
                self.effective_ttl(ttl_seconds),
            )
            .await?;
        Ok(keys.into_iter().next().unwrap_or_default())
    }

    /// Looks up cached responses whose prompt vector is within the distance
    /// threshold of the query.
    ///
    /// Either `prompt` or `vector` must be given. Returns up to `num_results`
    /// hits; each hit includes the entry's Redis key, and hits refresh the
    /// configured TTL (sliding expiration).
    pub fn check(
        &self,
        prompt: Option<&str>,
        vector: Option<&[f32]>,
        num_results: usize,
        return_fields: Option<&[&str]>,
        filter_expression: Option<FilterExpression>,
        distance_threshold: Option<f32>,
    ) -> Result<Vec<Map<String, Value>>> {
        let vector = self.resolve_query_vector(prompt, vector)?;
        let threshold = distance_threshold.unwrap_or(self.distance_threshold);
        validate_distance_threshold(threshold)?;
        let query = self.build_range_query(vector, threshold, num_results, filter_expression);
        let hits = process_semantic_hits(
            query_output_documents(self.index.query(&query)?)?,
            return_fields,
        )?;
        self.refresh_ttl_sync(&hits)?;
        Ok(hits)
    }

    /// Async variant of [`SemanticCache::check`].
    pub async fn acheck(
        &self,
        prompt: Option<&str>,
        vector: Option<&[f32]>,
        num_results: usize,
        return_fields: Option<&[&str]>,
        filter_expression: Option<FilterExpression>,
        distance_threshold: Option<f32>,
    ) -> Result<Vec<Map<String, Value>>> {
        let vector = self.resolve_query_vector(prompt, vector)?;
        let threshold = distance_threshold.unwrap_or(self.distance_threshold);
        validate_distance_threshold(threshold)?;
        let query = self.build_range_query(vector, threshold, num_results, filter_expression);
        let hits = process_semantic_hits(
            query_output_documents(self.async_index().query(&query).await?)?,
            return_fields,
        )?;
        self.refresh_ttl_async(&hits).await?;
        Ok(hits)
    }

    /// Updates fields of an existing entry via HSET and re-applies the TTL.
    /// Vector updates are rejected; `updated_at` is refreshed automatically.
    pub fn update(&self, key: &str, fields: Map<String, Value>) -> Result<()> {
        let mapping = prepare_semantic_update_fields(fields)?;
        let client = self.config.connection.client()?;
        let mut connection = client.get_connection()?;
        let mut cmd = redis::cmd("HSET");
        cmd.arg(key);
        for (field, value) in mapping {
            cmd.arg(field).arg(value);
        }
        let _: usize = cmd.query(&mut connection)?;
        self.expire_key(key, None)
    }

    /// Async variant of [`SemanticCache::update`].
    pub async fn aupdate(&self, key: &str, fields: Map<String, Value>) -> Result<()> {
        let mapping = prepare_semantic_update_fields(fields)?;
        let client = self.config.connection.client()?;
        let mut connection = client.get_multiplexed_async_connection().await?;
        let mut cmd = redis::cmd("HSET");
        cmd.arg(key);
        for (field, value) in mapping {
            cmd.arg(field).arg(value);
        }
        let _: usize = cmd.query_async(&mut connection).await?;
        self.aexpire_key(key, None).await
    }

    /// Removes all entries; returns the number deleted.
    pub fn clear(&self) -> Result<usize> {
        self.index.clear()
    }

    /// Async variant of [`SemanticCache::clear`].
    pub async fn aclear(&self) -> Result<usize> {
        self.async_index().clear().await
    }

    /// Drops the index together with its documents.
    pub fn delete(&self) -> Result<()> {
        self.index.delete(true)
    }

    /// Async variant of [`SemanticCache::delete`].
    pub async fn adelete(&self) -> Result<()> {
        self.async_index().delete(true).await
    }

    /// Deletes entries by entry id (expanded to full keys via the index).
    pub fn drop_ids(&self, ids: &[String]) -> Result<()> {
        let keys = ids.iter().map(|id| self.index.key(id)).collect::<Vec<_>>();
        self.index.drop_keys(&keys)?;
        Ok(())
    }

    /// Deletes entries by full Redis key.
    pub fn drop_keys(&self, keys: &[String]) -> Result<()> {
        self.index.drop_keys(keys)?;
        Ok(())
    }

    /// Async variant of [`SemanticCache::drop_ids`].
    pub async fn adrop_ids(&self, ids: &[String]) -> Result<()> {
        let keys = ids.iter().map(|id| self.index.key(id)).collect::<Vec<_>>();
        self.async_index().drop_keys(&keys).await?;
        Ok(())
    }

    /// Async variant of [`SemanticCache::drop_keys`].
    pub async fn adrop_keys(&self, keys: &[String]) -> Result<()> {
        self.async_index().drop_keys(keys).await?;
        Ok(())
    }

    /// Builds the hash record persisted for one prompt/response pair:
    /// validates metadata, resolves (or embeds) the vector, and stamps both
    /// timestamp fields with the current time. Shared by `store`/`astore`.
    fn build_store_record(
        &self,
        prompt: &str,
        response: &str,
        vector: Option<&[f32]>,
        metadata: Option<Value>,
        filters: Option<Map<String, Value>>,
    ) -> Result<Map<String, Value>> {
        if let Some(metadata) = metadata.as_ref() {
            validate_metadata(metadata)?;
        }
        let vector = self.resolve_vector(prompt, vector)?;
        let timestamp = current_timestamp();
        let entry_id = semantic_entry_id(prompt, filters.as_ref());
        let mut record = Map::new();
        record.insert(SEMANTIC_ENTRY_ID_FIELD.to_owned(), Value::String(entry_id));
        record.insert(
            SEMANTIC_PROMPT_FIELD.to_owned(),
            Value::String(prompt.to_owned()),
        );
        record.insert(
            SEMANTIC_RESPONSE_FIELD.to_owned(),
            Value::String(response.to_owned()),
        );
        record.insert(
            SEMANTIC_VECTOR_FIELD.to_owned(),
            Value::Array(
                vector
                    .iter()
                    .copied()
                    .map(|value| number_value(f64::from(value)))
                    .collect(),
            ),
        );
        record.insert(
            SEMANTIC_INSERTED_AT_FIELD.to_owned(),
            number_value(timestamp),
        );
        record.insert(
            SEMANTIC_UPDATED_AT_FIELD.to_owned(),
            number_value(timestamp),
        );
        if let Some(metadata) = metadata {
            record.insert(SEMANTIC_METADATA_FIELD.to_owned(), metadata);
        }
        if let Some(filters) = filters {
            for (key, value) in filters {
                record.insert(key, value);
            }
        }
        Ok(record)
    }

    /// TTL for a write: an explicit override wins over the configured default.
    fn effective_ttl(&self, ttl_override: Option<u64>) -> Option<i64> {
        ttl_override
            .or(self.config.ttl_seconds)
            .map(|value| value as i64)
    }

    /// Builds the vector range query shared by `check`/`acheck`.
    fn build_range_query(
        &self,
        vector: Vec<f32>,
        threshold: f32,
        num_results: usize,
        filter_expression: Option<FilterExpression>,
    ) -> VectorRangeQuery {
        let mut query =
            VectorRangeQuery::new(Vector::new(vector), SEMANTIC_VECTOR_FIELD, threshold)
                .paging(0, num_results)
                .with_return_fields(self.return_fields.iter().map(String::as_str));
        if let Some(filter_expression) = filter_expression {
            query = query.with_filter(filter_expression);
        }
        query
    }

    /// Resolves the query vector from either an explicit vector or a prompt.
    fn resolve_query_vector(
        &self,
        prompt: Option<&str>,
        vector: Option<&[f32]>,
    ) -> Result<Vec<f32>> {
        match (prompt, vector) {
            (_, Some(vector)) => self.validate_vector(vector),
            (Some(prompt), None) => self.resolve_vector(prompt, None),
            (None, None) => Err(crate::Error::InvalidInput(
                "either prompt or vector must be specified".to_owned(),
            )),
        }
    }

    /// Uses the provided vector, or embeds `prompt` with the vectorizer.
    fn resolve_vector(&self, prompt: &str, vector: Option<&[f32]>) -> Result<Vec<f32>> {
        match vector {
            Some(vector) => self.validate_vector(vector),
            None => {
                let Some(vectorizer) = &self.vectorizer else {
                    return Err(crate::Error::InvalidInput(
                        "a vector or configured vectorizer is required".to_owned(),
                    ));
                };
                let vector = vectorizer.embed(prompt)?;
                self.validate_vector(&vector)
            }
        }
    }

    /// Checks that `vector` matches the configured dimensionality.
    fn validate_vector(&self, vector: &[f32]) -> Result<Vec<f32>> {
        if vector.len() != self.vector_dimensions {
            return Err(crate::Error::InvalidInput(format!(
                "vector dimensions mismatch: expected {}, got {}",
                self.vector_dimensions,
                vector.len()
            )));
        }
        Ok(vector.to_vec())
    }

    /// Async view over the same schema/connection as the sync index.
    fn async_index(&self) -> AsyncSearchIndex {
        AsyncSearchIndex::new(
            self.index.schema().clone(),
            self.config.connection.redis_url.clone(),
        )
    }

    /// Sliding expiration: re-applies the configured TTL to every hit.
    fn refresh_ttl_sync(&self, hits: &[Map<String, Value>]) -> Result<()> {
        if self.config.ttl_seconds.is_none() {
            return Ok(());
        }
        for hit in hits {
            if let Some(key) = hit.get(SEMANTIC_KEY_FIELD).and_then(Value::as_str) {
                self.expire_key(key, None)?;
            }
        }
        Ok(())
    }

    /// Async variant of `refresh_ttl_sync`.
    async fn refresh_ttl_async(&self, hits: &[Map<String, Value>]) -> Result<()> {
        if self.config.ttl_seconds.is_none() {
            return Ok(());
        }
        for hit in hits {
            if let Some(key) = hit.get(SEMANTIC_KEY_FIELD).and_then(Value::as_str) {
                self.aexpire_key(key, None).await?;
            }
        }
        Ok(())
    }

    /// Applies EXPIRE to `key` using the override or configured TTL.
    /// No-op when neither is set.
    fn expire_key(&self, key: &str, ttl_override: Option<u64>) -> Result<()> {
        if let Some(ttl_seconds) = ttl_override.or(self.config.ttl_seconds) {
            let client = self.config.connection.client()?;
            let mut connection = client.get_connection()?;
            let _: bool = redis::cmd("EXPIRE")
                .arg(key)
                .arg(ttl_seconds)
                .query(&mut connection)?;
        }
        Ok(())
    }

    /// Async variant of `expire_key`.
    async fn aexpire_key(&self, key: &str, ttl_override: Option<u64>) -> Result<()> {
        if let Some(ttl_seconds) = ttl_override.or(self.config.ttl_seconds) {
            let client = self.config.connection.client()?;
            let mut connection = client.get_multiplexed_async_connection().await?;
            let _: bool = redis::cmd("EXPIRE")
                .arg(key)
                .arg(ttl_seconds)
                .query_async(&mut connection)
                .await?;
        }
        Ok(())
    }
}
impl std::fmt::Debug for SemanticCache {
    /// Manual `Debug`: only plain configuration fields are printed; the
    /// vectorizer trait object and index handle are omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("SemanticCache");
        builder.field("config", &self.config);
        builder.field("distance_threshold", &self.distance_threshold);
        builder.field("vector_dimensions", &self.vector_dimensions);
        builder.field("index_name", &self.index.name());
        builder.finish()
    }
}
/// Exact-match embedding cache keyed by SHA-256 of content and model name.
#[derive(Debug, Clone)]
pub struct EmbeddingsCache {
/// Name (used as the key prefix), connection, and default TTL.
pub config: CacheConfig,
}
impl Default for EmbeddingsCache {
fn default() -> Self {
Self::new(CacheConfig::default())
}
}
impl EmbeddingsCache {
    /// Creates an embeddings cache over the given config.
    pub fn new(config: CacheConfig) -> Self {
        Self { config }
    }

    /// Deterministic entry id: SHA-256 hex of `"{content}:{model_name}"`.
    pub fn make_entry_id(&self, content: &str, model_name: &str) -> String {
        hashify(&format!("{content}:{model_name}"))
    }

    /// Full Redis key (`"{cache_name}:{entry_id}"`) for a content/model pair.
    pub fn make_cache_key(&self, content: &str, model_name: &str) -> String {
        let entry_id = self.make_entry_id(content, model_name);
        self.key_for_entry(&entry_id)
    }

    /// Fetches the cached embedding for a content/model pair, if present.
    pub fn get(&self, content: &str, model_name: &str) -> Result<Option<EmbeddingCacheEntry>> {
        let key = self.make_cache_key(content, model_name);
        self.get_by_key(&key)
    }

    /// Fetches an entry by full key. A hit refreshes the configured TTL
    /// (sliding expiration) before the hash is parsed.
    pub fn get_by_key(&self, key: &str) -> Result<Option<EmbeddingCacheEntry>> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_connection()?;
        let data: HashMap<String, String> =
            redis::cmd("HGETALL").arg(key).query(&mut connection)?;
        if data.is_empty() {
            return Ok(None);
        }
        self.expire_key(key, None)?;
        parse_entry(data)
    }

    /// Batch `get` for several contents under one model.
    pub fn mget<I, S>(
        &self,
        contents: I,
        model_name: &str,
    ) -> Result<Vec<Option<EmbeddingCacheEntry>>>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let keys = contents
            .into_iter()
            .map(|content| self.make_cache_key(content.as_ref(), model_name))
            .collect::<Vec<_>>();
        self.mget_by_keys(keys)
    }

    /// Batch `get_by_key`. Fetches sequentially so each hit still gets its
    /// TTL refreshed.
    pub fn mget_by_keys<I, S>(&self, keys: I) -> Result<Vec<Option<EmbeddingCacheEntry>>>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let keys = collect_strings(keys);
        if keys.is_empty() {
            return Ok(Vec::new());
        }
        let mut results = Vec::with_capacity(keys.len());
        for key in &keys {
            results.push(self.get_by_key(key)?);
        }
        Ok(results)
    }

    /// Stores one embedding and returns its Redis key. `ttl_seconds`
    /// overrides the configured default TTL for this entry.
    pub fn set(
        &self,
        content: &str,
        model_name: &str,
        embedding: &[f32],
        metadata: Option<Value>,
        ttl_seconds: Option<u64>,
    ) -> Result<String> {
        let entry = self.prepare_entry(content, model_name, embedding, metadata);
        let key = self.key_for_entry(&entry.entry_id);
        self.write_entry(&key, &entry)?;
        self.expire_key(&key, ttl_seconds)?;
        Ok(key)
    }

    /// Stores several embeddings; returns their keys in input order.
    pub fn mset(
        &self,
        items: &[EmbeddingCacheItem],
        ttl_seconds: Option<u64>,
    ) -> Result<Vec<String>> {
        let mut keys = Vec::with_capacity(items.len());
        for item in items {
            let key = self.set(
                &item.content,
                &item.model_name,
                &item.embedding,
                item.metadata.clone(),
                ttl_seconds,
            )?;
            keys.push(key);
        }
        Ok(keys)
    }

    /// True if an entry exists for the content/model pair.
    pub fn exists(&self, content: &str, model_name: &str) -> Result<bool> {
        let key = self.make_cache_key(content, model_name);
        self.exists_by_key(&key)
    }

    /// True if `key` exists in Redis.
    pub fn exists_by_key(&self, key: &str) -> Result<bool> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_connection()?;
        let exists: u64 = redis::cmd("EXISTS").arg(key).query(&mut connection)?;
        Ok(exists > 0)
    }

    /// Batch `exists` for several contents under one model.
    pub fn mexists<I, S>(&self, contents: I, model_name: &str) -> Result<Vec<bool>>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let keys = contents
            .into_iter()
            .map(|content| self.make_cache_key(content.as_ref(), model_name))
            .collect::<Vec<_>>();
        self.mexists_by_keys(keys)
    }

    /// Batch `exists_by_key`, reusing a single connection.
    pub fn mexists_by_keys<I, S>(&self, keys: I) -> Result<Vec<bool>>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let keys = collect_strings(keys);
        if keys.is_empty() {
            return Ok(Vec::new());
        }
        let client = self.config.connection.client()?;
        let mut connection = client.get_connection()?;
        let mut results = Vec::with_capacity(keys.len());
        for key in keys {
            let exists: u64 = redis::cmd("EXISTS").arg(key).query(&mut connection)?;
            results.push(exists > 0);
        }
        Ok(results)
    }

    /// Deletes the entry for a content/model pair (no-op if absent).
    pub fn drop(&self, content: &str, model_name: &str) -> Result<()> {
        let key = self.make_cache_key(content, model_name);
        self.drop_by_key(&key)
    }

    /// Deletes an entry by full key (no-op if absent).
    pub fn drop_by_key(&self, key: &str) -> Result<()> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_connection()?;
        let _: usize = redis::cmd("DEL").arg(key).query(&mut connection)?;
        Ok(())
    }

    /// Batch `drop` for several contents under one model.
    pub fn mdrop<I, S>(&self, contents: I, model_name: &str) -> Result<()>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let keys = contents
            .into_iter()
            .map(|content| self.make_cache_key(content.as_ref(), model_name))
            .collect::<Vec<_>>();
        self.mdrop_by_keys(keys)
    }

    /// Deletes several keys with a single DEL command.
    pub fn mdrop_by_keys<I, S>(&self, keys: I) -> Result<()>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let keys = collect_strings(keys);
        if keys.is_empty() {
            return Ok(());
        }
        let client = self.config.connection.client()?;
        let mut connection = client.get_connection()?;
        let _: usize = redis::cmd("DEL").arg(keys).query(&mut connection)?;
        Ok(())
    }

    /// Deletes every entry under this cache's prefix; returns the count.
    pub fn clear(&self) -> Result<usize> {
        let keys = self.all_keys()?;
        if keys.is_empty() {
            return Ok(0);
        }
        let count = keys.len();
        self.mdrop_by_keys(keys)?;
        Ok(count)
    }

    /// Async variant of [`EmbeddingsCache::get`].
    pub async fn aget(
        &self,
        content: &str,
        model_name: &str,
    ) -> Result<Option<EmbeddingCacheEntry>> {
        let key = self.make_cache_key(content, model_name);
        self.aget_by_key(&key).await
    }

    /// Async variant of [`EmbeddingsCache::get_by_key`].
    pub async fn aget_by_key(&self, key: &str) -> Result<Option<EmbeddingCacheEntry>> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_multiplexed_async_connection().await?;
        let data: HashMap<String, String> = redis::cmd("HGETALL")
            .arg(key)
            .query_async(&mut connection)
            .await?;
        if data.is_empty() {
            return Ok(None);
        }
        self.aexpire_key(key, None).await?;
        parse_entry(data)
    }

    /// Async variant of [`EmbeddingsCache::mget`].
    pub async fn amget<I, S>(
        &self,
        contents: I,
        model_name: &str,
    ) -> Result<Vec<Option<EmbeddingCacheEntry>>>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let keys = contents
            .into_iter()
            .map(|content| self.make_cache_key(content.as_ref(), model_name))
            .collect::<Vec<_>>();
        self.amget_by_keys(keys).await
    }

    /// Async variant of [`EmbeddingsCache::mget_by_keys`].
    pub async fn amget_by_keys<I, S>(&self, keys: I) -> Result<Vec<Option<EmbeddingCacheEntry>>>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let keys = collect_strings(keys);
        if keys.is_empty() {
            return Ok(Vec::new());
        }
        let mut results = Vec::with_capacity(keys.len());
        for key in &keys {
            results.push(self.aget_by_key(key).await?);
        }
        Ok(results)
    }

    /// Async variant of [`EmbeddingsCache::set`].
    pub async fn aset(
        &self,
        content: &str,
        model_name: &str,
        embedding: &[f32],
        metadata: Option<Value>,
        ttl_seconds: Option<u64>,
    ) -> Result<String> {
        let entry = self.prepare_entry(content, model_name, embedding, metadata);
        let key = self.key_for_entry(&entry.entry_id);
        self.awrite_entry(&key, &entry).await?;
        self.aexpire_key(&key, ttl_seconds).await?;
        Ok(key)
    }

    /// Async variant of [`EmbeddingsCache::mset`].
    pub async fn amset(
        &self,
        items: &[EmbeddingCacheItem],
        ttl_seconds: Option<u64>,
    ) -> Result<Vec<String>> {
        let mut keys = Vec::with_capacity(items.len());
        for item in items {
            let key = self
                .aset(
                    &item.content,
                    &item.model_name,
                    &item.embedding,
                    item.metadata.clone(),
                    ttl_seconds,
                )
                .await?;
            keys.push(key);
        }
        Ok(keys)
    }

    /// Async variant of [`EmbeddingsCache::exists`].
    pub async fn aexists(&self, content: &str, model_name: &str) -> Result<bool> {
        let key = self.make_cache_key(content, model_name);
        self.aexists_by_key(&key).await
    }

    /// Async variant of [`EmbeddingsCache::exists_by_key`].
    pub async fn aexists_by_key(&self, key: &str) -> Result<bool> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_multiplexed_async_connection().await?;
        Ok(connection.exists(key).await?)
    }

    /// Async variant of [`EmbeddingsCache::mexists`].
    pub async fn amexists<I, S>(&self, contents: I, model_name: &str) -> Result<Vec<bool>>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let keys = contents
            .into_iter()
            .map(|content| self.make_cache_key(content.as_ref(), model_name))
            .collect::<Vec<_>>();
        self.amexists_by_keys(keys).await
    }

    /// Async variant of [`EmbeddingsCache::mexists_by_keys`].
    pub async fn amexists_by_keys<I, S>(&self, keys: I) -> Result<Vec<bool>>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let keys = collect_strings(keys);
        if keys.is_empty() {
            return Ok(Vec::new());
        }
        let client = self.config.connection.client()?;
        let mut connection = client.get_multiplexed_async_connection().await?;
        let mut results = Vec::with_capacity(keys.len());
        for key in keys {
            results.push(connection.exists(key).await?);
        }
        Ok(results)
    }

    /// Async variant of [`EmbeddingsCache::drop`].
    pub async fn adrop(&self, content: &str, model_name: &str) -> Result<()> {
        let key = self.make_cache_key(content, model_name);
        self.adrop_by_key(&key).await
    }

    /// Async variant of [`EmbeddingsCache::drop_by_key`].
    pub async fn adrop_by_key(&self, key: &str) -> Result<()> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_multiplexed_async_connection().await?;
        let _: usize = connection.del(key).await?;
        Ok(())
    }

    /// Async variant of [`EmbeddingsCache::mdrop`].
    pub async fn amdrop<I, S>(&self, contents: I, model_name: &str) -> Result<()>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let keys = contents
            .into_iter()
            .map(|content| self.make_cache_key(content.as_ref(), model_name))
            .collect::<Vec<_>>();
        self.amdrop_by_keys(keys).await
    }

    /// Async variant of [`EmbeddingsCache::mdrop_by_keys`].
    pub async fn amdrop_by_keys<I, S>(&self, keys: I) -> Result<()>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let keys = collect_strings(keys);
        if keys.is_empty() {
            return Ok(());
        }
        let client = self.config.connection.client()?;
        let mut connection = client.get_multiplexed_async_connection().await?;
        let _: usize = connection.del(keys).await?;
        Ok(())
    }

    /// Async variant of [`EmbeddingsCache::clear`].
    pub async fn aclear(&self) -> Result<usize> {
        let keys = self.aall_keys().await?;
        if keys.is_empty() {
            return Ok(0);
        }
        let count = keys.len();
        self.amdrop_by_keys(keys).await?;
        Ok(count)
    }

    /// Builds the owned entry struct (including its derived id) for a write.
    fn prepare_entry(
        &self,
        content: &str,
        model_name: &str,
        embedding: &[f32],
        metadata: Option<Value>,
    ) -> EmbeddingCacheEntry {
        EmbeddingCacheEntry {
            entry_id: self.make_entry_id(content, model_name),
            content: content.to_owned(),
            model_name: model_name.to_owned(),
            embedding: embedding.to_vec(),
            metadata,
        }
    }

    /// Builds the HSET command that persists `entry` under `key`; the
    /// embedding (and metadata, when present) are stored JSON-serialized.
    /// Shared by the sync and async write paths.
    fn entry_hset_cmd(key: &str, entry: &EmbeddingCacheEntry) -> Result<redis::Cmd> {
        let mut cmd = redis::cmd("HSET");
        cmd.arg(key)
            .arg("entry_id")
            .arg(&entry.entry_id)
            .arg("content")
            .arg(&entry.content)
            .arg("model_name")
            .arg(&entry.model_name)
            .arg("embedding")
            .arg(serde_json::to_string(&entry.embedding)?);
        if let Some(metadata) = &entry.metadata {
            cmd.arg("metadata").arg(serde_json::to_string(metadata)?);
        }
        Ok(cmd)
    }

    /// Writes the entry hash synchronously.
    fn write_entry(&self, key: &str, entry: &EmbeddingCacheEntry) -> Result<()> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_connection()?;
        let _: usize = Self::entry_hset_cmd(key, entry)?.query(&mut connection)?;
        Ok(())
    }

    /// Writes the entry hash asynchronously.
    async fn awrite_entry(&self, key: &str, entry: &EmbeddingCacheEntry) -> Result<()> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_multiplexed_async_connection().await?;
        let _: usize = Self::entry_hset_cmd(key, entry)?
            .query_async(&mut connection)
            .await?;
        Ok(())
    }

    /// Applies EXPIRE using the override or configured TTL (no-op when
    /// neither is set).
    fn expire_key(&self, key: &str, ttl_override: Option<u64>) -> Result<()> {
        if let Some(ttl_seconds) = ttl_override.or(self.config.ttl_seconds) {
            let client = self.config.connection.client()?;
            let mut connection = client.get_connection()?;
            let _: bool = redis::cmd("EXPIRE")
                .arg(key)
                .arg(ttl_seconds)
                .query(&mut connection)?;
        }
        Ok(())
    }

    /// Async variant of `expire_key`.
    async fn aexpire_key(&self, key: &str, ttl_override: Option<u64>) -> Result<()> {
        if let Some(ttl_seconds) = ttl_override.or(self.config.ttl_seconds) {
            let client = self.config.connection.client()?;
            let mut connection = client.get_multiplexed_async_connection().await?;
            let _: bool = redis::cmd("EXPIRE")
                .arg(key)
                .arg(ttl_seconds)
                .query_async(&mut connection)
                .await?;
        }
        Ok(())
    }

    /// Lists every key under this cache's prefix.
    /// NOTE(review): KEYS is O(N) and blocks Redis; consider SCAN for large
    /// deployments.
    fn all_keys(&self) -> Result<Vec<String>> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_connection()?;
        let keys: Vec<String> = redis::cmd("KEYS")
            .arg(format!("{}:*", self.config.name))
            .query(&mut connection)?;
        Ok(keys)
    }

    /// Async variant of `all_keys` (same KEYS caveat applies).
    async fn aall_keys(&self) -> Result<Vec<String>> {
        let client = self.config.connection.client()?;
        let mut connection = client.get_multiplexed_async_connection().await?;
        let keys: Vec<String> = redis::cmd("KEYS")
            .arg(format!("{}:*", self.config.name))
            .query_async(&mut connection)
            .await?;
        Ok(keys)
    }

    /// Prefixes an entry id with the cache name to form the Redis key.
    fn key_for_entry(&self, entry_id: &str) -> String {
        format!("{}:{entry_id}", self.config.name)
    }
}
/// Copies an iterator of string-like values into owned `String`s.
fn collect_strings<I, S>(values: I) -> Vec<String>
where
    I: IntoIterator<Item = S>,
    S: AsRef<str>,
{
    let mut owned = Vec::new();
    for value in values {
        owned.push(value.as_ref().to_owned());
    }
    owned
}
/// Reconstructs an `EmbeddingCacheEntry` from the raw Redis hash fields.
///
/// Missing scalar fields default to empty strings; a missing embedding
/// becomes an empty vector. Returns `Ok(None)` for an empty hash (absent key).
///
/// # Errors
/// Fails when the stored `embedding` or `metadata` JSON cannot be parsed.
fn parse_entry(mut data: HashMap<String, String>) -> Result<Option<EmbeddingCacheEntry>> {
    if data.is_empty() {
        return Ok(None);
    }
    let entry = EmbeddingCacheEntry {
        // `remove` moves the owned strings out of the map instead of cloning.
        entry_id: data.remove("entry_id").unwrap_or_default(),
        content: data.remove("content").unwrap_or_default(),
        model_name: data.remove("model_name").unwrap_or_default(),
        embedding: match data.remove("embedding") {
            Some(value) => serde_json::from_str::<Vec<f32>>(&value)?,
            None => Vec::new(),
        },
        metadata: data
            .remove("metadata")
            .map(|value| serde_json::from_str::<Value>(&value))
            .transpose()?,
    };
    Ok(Some(entry))
}
fn hashify(content: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(content.as_bytes());
let digest = hasher.finalize();
let mut output = String::with_capacity(digest.len() * 2);
for byte in digest {
use std::fmt::Write as _;
let _ = write!(&mut output, "{byte:02x}");
}
output
}
/// Assembles the JSON index schema for a semantic cache named `name`.
///
/// Hash storage, a flat (brute-force) cosine vector field of
/// `vector_dimensions` dims, the built-in tag/text/numeric fields, and any
/// caller-supplied `filterable_fields` appended verbatim.
fn semantic_cache_schema(
    name: &str,
    vector_dimensions: usize,
    dtype: VectorDataType,
    filterable_fields: &[Value],
) -> Value {
    let vector_field = json!({
        "name": SEMANTIC_VECTOR_FIELD,
        "type": "vector",
        "attrs": {
            "algorithm": "flat",
            "dims": vector_dimensions,
            "datatype": dtype.as_str(),
            "distance_metric": "cosine"
        }
    });
    let mut fields = Vec::with_capacity(7 + filterable_fields.len());
    fields.push(json!({ "name": SEMANTIC_ENTRY_ID_FIELD, "type": "tag" }));
    fields.push(json!({ "name": SEMANTIC_PROMPT_FIELD, "type": "text" }));
    fields.push(json!({ "name": SEMANTIC_RESPONSE_FIELD, "type": "text" }));
    fields.push(json!({ "name": SEMANTIC_INSERTED_AT_FIELD, "type": "numeric" }));
    fields.push(json!({ "name": SEMANTIC_UPDATED_AT_FIELD, "type": "numeric" }));
    fields.push(json!({ "name": SEMANTIC_METADATA_FIELD, "type": "text" }));
    fields.push(vector_field);
    fields.extend(filterable_fields.iter().cloned());
    json!({
        "index": {
            "name": name,
            "prefix": name,
            "storage_type": "hash",
        },
        "fields": fields,
    })
}
/// Field names returned by semantic-cache queries by default.
fn default_semantic_return_fields() -> Vec<String> {
    [
        SEMANTIC_ENTRY_ID_FIELD,
        SEMANTIC_PROMPT_FIELD,
        SEMANTIC_RESPONSE_FIELD,
        "vector_distance",
        SEMANTIC_INSERTED_AT_FIELD,
        SEMANTIC_UPDATED_AT_FIELD,
        SEMANTIC_METADATA_FIELD,
    ]
    .into_iter()
    .map(|field| field.to_owned())
    .collect()
}
/// Current UTC time as fractional seconds since the Unix epoch
/// (millisecond precision).
fn current_timestamp() -> f64 {
    let millis = Utc::now().timestamp_millis();
    millis as f64 / 1000.0
}
/// Derives the deterministic entry id for a prompt, mixing in any filters so
/// the same prompt stored under different filters gets distinct entries.
fn semantic_entry_id(prompt: &str, filters: Option<&Map<String, Value>>) -> String {
    match filters {
        None => hashify(prompt),
        Some(filters) => {
            let mut parts: Vec<String> = filters
                .iter()
                .map(|(key, value)| format!("{key}{}", value_to_hash_string(value)))
                .collect();
            // Sort so the id is independent of filter insertion order.
            parts.sort();
            hashify(&format!("{prompt}{}", parts.concat()))
        }
    }
}
/// Canonical string form of a JSON value used when hashing filters.
fn value_to_hash_string(value: &Value) -> String {
    match value {
        Value::String(text) => text.clone(),
        Value::Null => "null".to_owned(),
        Value::Bool(flag) => flag.to_string(),
        Value::Number(number) => number.to_string(),
        // Composite values fall back to compact JSON serialization.
        composite => serde_json::to_string(composite).unwrap_or_default(),
    }
}
// Field names owned by the semantic-cache schema/query pipeline; callers may
// not reuse them as filterable-field names (see `validate_filterable_fields`).
const RESERVED_SEMANTIC_FIELDS: &[&str] = &[
SEMANTIC_ENTRY_ID_FIELD,
SEMANTIC_PROMPT_FIELD,
SEMANTIC_RESPONSE_FIELD,
SEMANTIC_VECTOR_FIELD,
SEMANTIC_INSERTED_AT_FIELD,
SEMANTIC_UPDATED_AT_FIELD,
SEMANTIC_METADATA_FIELD,
SEMANTIC_KEY_FIELD,
"vector_distance",
];
/// Validates user-supplied filterable field specs: each must carry a unique,
/// non-reserved `name` and a `type` of tag/text/numeric/geo.
fn validate_filterable_fields(fields: &[Value]) -> Result<()> {
    let mut seen = std::collections::HashSet::new();
    for field in fields {
        let str_attr = |key: &str| field.get(key).and_then(Value::as_str).unwrap_or_default();
        let name = str_attr("name");
        let field_type = str_attr("type");
        if name.is_empty() {
            return Err(crate::Error::InvalidInput(
                "filterable field must have a non-empty 'name'".to_owned(),
            ));
        }
        if RESERVED_SEMANTIC_FIELDS.contains(&name) {
            return Err(crate::Error::InvalidInput(format!(
                "{name} is a reserved field name for the semantic cache schema"
            )));
        }
        if !seen.insert(name.to_owned()) {
            return Err(crate::Error::InvalidInput(format!(
                "duplicate field name: {name}. Field names must be unique"
            )));
        }
        match field_type {
            "tag" | "text" | "numeric" | "geo" => {}
            _ => {
                return Err(crate::Error::InvalidInput(format!(
                    "invalid filterable field type: '{field_type}' for field '{name}'"
                )));
            }
        }
    }
    Ok(())
}
fn validate_distance_threshold(distance_threshold: f32) -> Result<()> {
if !(0.0..=2.0).contains(&distance_threshold) {
return Err(crate::Error::InvalidInput(format!(
"distance threshold must be between 0 and 2, got {distance_threshold}"
)));
}
Ok(())
}
fn validate_metadata(metadata: &Value) -> Result<()> {
if !metadata.is_object() {
return Err(crate::Error::InvalidInput(
"metadata must be a JSON object".to_owned(),
));
}
Ok(())
}
fn query_output_documents(output: QueryOutput) -> Result<Vec<Map<String, Value>>> {
match output {
QueryOutput::Documents(documents) => Ok(documents),
QueryOutput::Count(_) => Err(crate::Error::InvalidInput(
"semantic cache queries must return documents".to_owned(),
)),
}
}
/// Post-processes raw query documents into cache hits.
///
/// Moves each document's `id` into the `key` field, optionally restricts the
/// remaining fields to `return_fields`, and normalizes known numeric/JSON
/// fields from their string wire format.
fn process_semantic_hits(
    documents: Vec<Map<String, Value>>,
    return_fields: Option<&[&str]>,
) -> Result<Vec<Map<String, Value>>> {
    let selected = return_fields.map(|fields| {
        fields
            .iter()
            .map(|field| (*field).to_owned())
            .collect::<std::collections::HashSet<_>>()
    });
    let mut hits = Vec::with_capacity(documents.len());
    for mut document in documents {
        let key = document
            .remove("id")
            .unwrap_or_else(|| Value::String(String::new()));
        let mut hit = Map::new();
        hit.insert(SEMANTIC_KEY_FIELD.to_owned(), key);
        for (field, value) in document {
            // `None` selection means "keep everything".
            if let Some(fields) = selected.as_ref() {
                if !fields.contains(&field) {
                    continue;
                }
            }
            // Normalize first so `field` can be moved into the map, not cloned.
            let value = normalize_semantic_value(&field, value)?;
            hit.insert(field, value);
        }
        hits.push(hit);
    }
    Ok(hits)
}
fn normalize_semantic_value(field: &str, value: Value) -> Result<Value> {
match (field, value) {
(SEMANTIC_METADATA_FIELD, Value::String(value)) => {
Ok(serde_json::from_str(&value).unwrap_or(Value::String(value)))
}
(
"vector_distance" | SEMANTIC_INSERTED_AT_FIELD | SEMANTIC_UPDATED_AT_FIELD,
Value::String(value),
) => {
let parsed = value.parse::<f64>().map_err(|_| {
crate::Error::InvalidInput(format!("could not parse numeric field '{field}'"))
})?;
Ok(number_value(parsed))
}
(_, value) => Ok(value),
}
}
/// Serializes update fields into `(field, string)` pairs for HSET and appends
/// a fresh `updated_at` timestamp.
///
/// # Errors
/// Rejects attempts to update the stored vector and non-object metadata.
fn prepare_semantic_update_fields(fields: Map<String, Value>) -> Result<Vec<(String, String)>> {
    let mut mapping = Vec::with_capacity(fields.len() + 1);
    for (field, value) in fields {
        if field == SEMANTIC_VECTOR_FIELD {
            return Err(crate::Error::InvalidInput(
                "updating the stored vector is not supported yet".to_owned(),
            ));
        }
        if field == SEMANTIC_METADATA_FIELD {
            validate_metadata(&value)?;
        }
        let serialized = match value {
            Value::String(text) => text,
            Value::Bool(flag) => flag.to_string(),
            Value::Number(number) => number.to_string(),
            Value::Null => "null".to_owned(),
            composite => serde_json::to_string(&composite)?,
        };
        mapping.push((field, serialized));
    }
    mapping.push((
        SEMANTIC_UPDATED_AT_FIELD.to_owned(),
        current_timestamp().to_string(),
    ));
    Ok(mapping)
}
/// Wraps an `f64` in a JSON number, falling back to `Null` for values that
/// JSON cannot represent (NaN and infinities).
fn number_value(value: f64) -> Value {
    match Number::from_f64(value) {
        Some(number) => Value::Number(number),
        None => Value::Null,
    }
}
#[cfg(test)]
mod tests {
    use serde_json::json;
    use super::{
        CacheConfig, EmbeddingsCache, hashify, validate_distance_threshold,
        validate_filterable_fields, validate_metadata,
    };

    /// Extracts the `datatype` attribute of the `prompt_vector` field from a
    /// generated schema value.
    fn prompt_vector_datatype(schema: &serde_json::Value) -> serde_json::Value {
        schema["fields"]
            .as_array()
            .expect("schema `fields` should be an array")
            .iter()
            .find(|field| field["name"] == "prompt_vector")
            .expect("schema should declare a `prompt_vector` field")["attrs"]["datatype"]
            .clone()
    }

    #[test]
    fn hashify_matches_expected_sha256() {
        // Known-answer test: SHA-256 hex digest of the content:model string.
        assert_eq!(
            hashify("Hello world:text-embedding-ada-002"),
            "368dacc611e96e4189a9809faaca1a70b3c3306352bbcfc9ab6291359a5dfca0"
        );
    }

    #[test]
    fn cache_key_is_stable() {
        // Cache keys are "<cache name>:<hash>" and must never drift.
        let cache = EmbeddingsCache::new(CacheConfig::default());
        assert_eq!(
            cache.make_cache_key("Hello world", "text-embedding-ada-002"),
            "embedcache:368dacc611e96e4189a9809faaca1a70b3c3306352bbcfc9ab6291359a5dfca0"
        );
    }

    #[test]
    fn entry_id_is_deterministic() {
        let cache = EmbeddingsCache::new(CacheConfig::default());
        let first = cache.make_entry_id("Hello world", "text-embedding-ada-002");
        let second = cache.make_entry_id("Hello world", "text-embedding-ada-002");
        assert_eq!(first, second);
        let other = cache.make_entry_id("Different text", "text-embedding-ada-002");
        assert_ne!(first, other);
    }

    #[test]
    fn entry_id_different_inputs_differ() {
        let cache = EmbeddingsCache::new(CacheConfig::default());
        assert_ne!(
            cache.make_entry_id("What is machine learning?", "text-embedding-ada-002"),
            cache.make_entry_id("How do neural networks work?", "text-embedding-ada-002")
        );
    }

    #[test]
    fn cache_key_includes_cache_name() {
        // Two caches with different names must namespace their keys apart.
        let first = EmbeddingsCache::new(CacheConfig::new("cache_a", "redis://localhost:6379"));
        let second = EmbeddingsCache::new(CacheConfig::new("cache_b", "redis://localhost:6379"));
        let key_a = first.make_cache_key("hello", "model");
        let key_b = second.make_cache_key("hello", "model");
        assert!(key_a.starts_with("cache_a:"));
        assert!(key_b.starts_with("cache_b:"));
        assert_ne!(key_a, key_b);
    }

    #[test]
    fn distance_threshold_out_of_range() {
        // Accepted range is [0, 2]; values outside it are rejected.
        for invalid in [-1.0, 2.5] {
            assert!(validate_distance_threshold(invalid).is_err());
        }
        for valid in [0.0, 1.0, 2.0] {
            assert!(validate_distance_threshold(valid).is_ok());
        }
    }

    #[test]
    fn metadata_must_be_object() {
        // Only JSON objects (including empty ones) are valid metadata.
        for rejected in [json!("string"), json!([1, 2]), json!(42)] {
            assert!(validate_metadata(&rejected).is_err());
        }
        for accepted in [json!({"key": "value"}), json!({})] {
            assert!(validate_metadata(&accepted).is_ok());
        }
    }

    #[test]
    fn filterable_fields_reserved_name() {
        let fields = vec![json!({"name": "metadata", "type": "tag"})];
        let error = validate_filterable_fields(&fields).unwrap_err();
        assert!(error.to_string().contains("reserved"));
    }

    #[test]
    fn filterable_fields_duplicate_name() {
        let fields = vec![
            json!({"name": "label", "type": "tag"}),
            json!({"name": "label", "type": "tag"}),
        ];
        let error = validate_filterable_fields(&fields).unwrap_err();
        assert!(error.to_string().contains("duplicate"));
    }

    #[test]
    fn filterable_fields_invalid_type() {
        let fields = vec![
            json!({"name": "label", "type": "tag"}),
            json!({"name": "test", "type": "nothing"}),
        ];
        let error = validate_filterable_fields(&fields).unwrap_err();
        assert!(error.to_string().contains("invalid"));
    }

    #[test]
    fn filterable_fields_valid() {
        let fields = vec![
            json!({"name": "label", "type": "tag"}),
            json!({"name": "score", "type": "numeric"}),
        ];
        assert!(validate_filterable_fields(&fields).is_ok());
    }

    #[test]
    fn default_embeddings_cache_name() {
        let cache = EmbeddingsCache::default();
        assert_eq!(cache.config.name, "embedcache");
        assert!(cache.config.ttl_seconds.is_none());
    }

    #[test]
    fn custom_embeddings_cache_config() {
        let cache = EmbeddingsCache::new(
            CacheConfig::new("custom_cache", "redis://localhost:6379").with_ttl(60),
        );
        assert_eq!(cache.config.name, "custom_cache");
        assert_eq!(cache.config.ttl_seconds, Some(60));
    }

    #[test]
    fn semantic_cache_schema_respects_dtype() {
        use super::{VectorDataType, semantic_cache_schema};
        // Every supported dtype must appear verbatim in the vector field attrs.
        let cases = [
            (VectorDataType::Float32, "float32"),
            (VectorDataType::Float64, "float64"),
            (VectorDataType::Bfloat16, "bfloat16"),
            (VectorDataType::Float16, "float16"),
        ];
        for (dtype, expected) in cases {
            let schema = semantic_cache_schema("test", 128, dtype, &[]);
            assert_eq!(prompt_vector_datatype(&schema), expected);
        }
    }
}