use rig_core::vector_store::request::{Filter as CoreFilter, FilterError, SearchFilter};
use serde::{Deserialize, Serialize};
/// A scalar that can appear on the right-hand side of a Redis filter clause.
///
/// `Filter::eq` accepts every variant; the ordering comparisons (`gt`, `lt`, `gte`, `lte`)
/// only accept [`RedisValue::Number`].
#[derive(Clone, Debug, PartialEq)]
pub enum RedisValue {
    /// A numeric value; `Filter::eq` renders it as a closed numeric range.
    Number(f64),
    /// A string; `Filter::eq` renders it as an escaped tag match.
    String(String),
    /// A boolean; `Filter::eq` renders it as the tag `1` or `0`.
    Bool(bool),
}
/// A finite `f64` that can be rendered into a Redis numeric range clause.
///
/// Construction from an `f64` goes through [`RedisNumber::new`], which rejects NaN and
/// ±infinity. Integer conversions are infallible.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RedisNumber(f64);

impl RedisNumber {
    /// Wraps `value` if it is finite.
    ///
    /// # Errors
    ///
    /// Returns [`FilterError::Expected`] when `value` is NaN or infinite.
    pub fn new(value: f64) -> Result<Self, FilterError> {
        if value.is_finite() {
            Ok(Self(value))
        } else {
            Err(FilterError::Expected {
                expected: "finite numeric value for Redis numeric filter".into(),
                got: value.to_string(),
            })
        }
    }

    /// Returns the wrapped value.
    fn get(self) -> f64 {
        self.0
    }
}

impl TryFrom<f64> for RedisNumber {
    type Error = FilterError;

    /// Same as [`RedisNumber::new`].
    fn try_from(value: f64) -> Result<Self, Self::Error> {
        Self::new(value)
    }
}

/// Lossless conversion. An unsuffixed integer literal defaults to `i32`, so this lets
/// literals such as `RedisNumber::from(5)` compile.
impl From<i32> for RedisNumber {
    fn from(value: i32) -> Self {
        Self(f64::from(value))
    }
}

/// Lossless conversion from `u32`.
impl From<u32> for RedisNumber {
    fn from(value: u32) -> Self {
        Self(f64::from(value))
    }
}

/// Converts with `as f64`.
///
/// Magnitudes above 2^53 may be rounded to the nearest representable `f64`.
impl From<i64> for RedisNumber {
    fn from(value: i64) -> Self {
        Self(value as f64)
    }
}

/// Converts with `as f64`.
///
/// Values above 2^53 may be rounded to the nearest representable `f64`.
impl From<u64> for RedisNumber {
    fn from(value: u64) -> Self {
        Self(value as f64)
    }
}
/// Extracts a finite number from `value` to use as the bound of an ordering comparison.
///
/// `operation` is only used to name the comparison in error messages.
///
/// # Errors
///
/// Returns [`FilterError::Expected`] if `value` is not a [`RedisValue::Number`], or if the
/// number is NaN or infinite.
fn numeric_bound(value: RedisValue, operation: &'static str) -> Result<RedisNumber, FilterError> {
    if let RedisValue::Number(n) = value {
        return if n.is_finite() {
            Ok(RedisNumber(n))
        } else {
            Err(FilterError::Expected {
                expected: format!("finite numeric value for Redis {operation} filter"),
                got: n.to_string(),
            })
        };
    }

    // Strings and booleans cannot bound a numeric range.
    Err(FilterError::Expected {
        expected: format!("numeric value for Redis {operation} filter"),
        got: format!("{value:?}"),
    })
}
/// Formats the numeric range clause `@<field>:[<lower> <upper>]`.
///
/// `lower` and `upper` are inserted verbatim, so callers supply any `(` exclusivity prefix
/// and the `-inf` / `+inf` sentinels themselves. The key is escaped with `field_name`.
fn numeric_range_clause(key: impl AsRef<str>, lower: &str, upper: &str) -> Filter {
    Filter(format!("@{}:[{} {}]", field_name(key), lower, upper))
}

/// Numeric equality, expressed as the closed range `[v v]`.
fn numeric_eq_filter(key: impl AsRef<str>, value: RedisNumber) -> Filter {
    let bound = value.get().to_string();
    numeric_range_clause(key, &bound, &bound)
}

/// Strictly greater than `value`: `[(v +inf]`.
fn gt_number_filter(key: impl AsRef<str>, value: RedisNumber) -> Filter {
    numeric_range_clause(key, &format!("({}", value.get()), "+inf")
}

/// Strictly less than `value`: `[-inf (v]`.
fn lt_number_filter(key: impl AsRef<str>, value: RedisNumber) -> Filter {
    numeric_range_clause(key, "-inf", &format!("({}", value.get()))
}

/// Greater than or equal to `value`: `[v +inf]`.
fn gte_number_filter(key: impl AsRef<str>, value: RedisNumber) -> Filter {
    numeric_range_clause(key, &value.get().to_string(), "+inf")
}

/// Less than or equal to `value`: `[-inf v]`.
fn lte_number_filter(key: impl AsRef<str>, value: RedisNumber) -> Filter {
    numeric_range_clause(key, "-inf", &value.get().to_string())
}
/// Escapes a possibly dotted key for use after `@` in a query.
///
/// The key is split on `.`. Each segment is escaped independently and the dots are kept
/// literally, so `a.b-c` becomes `a.b\-c`.
fn field_name(key: impl AsRef<str>) -> String {
    let mut rendered = String::new();
    for (index, segment) in key.as_ref().split('.').enumerate() {
        if index > 0 {
            rendered.push('.');
        }
        rendered.push_str(&escape_field_segment(segment));
    }
    rendered
}

/// Backslash-escapes every character of `segment` except ASCII letters, digits and `_`.
fn escape_field_segment(segment: &str) -> String {
    segment
        .chars()
        .fold(String::with_capacity(segment.len()), |mut acc, ch| {
            let plain = ch.is_ascii_alphanumeric() || ch == '_';
            if !plain {
                acc.push('\\');
            }
            acc.push(ch);
            acc
        })
}
/// Escapes `value` so it can be placed inside a tag clause such as `@field:{...}`.
///
/// A backslash is inserted before:
/// * every ASCII whitespace character (space, tab, line feed, form feed, carriage return);
/// * each punctuation character in the list below.
///
/// All other characters are copied unchanged. Escaping all ASCII whitespace, not just the
/// space character, keeps tabs and line breaks in user-supplied values from reaching the
/// query parser bare.
pub fn escape_tag_value(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        let needs_escape = ch.is_ascii_whitespace()
            || matches!(
                ch,
                '\\' | ','
                    | '.'
                    | '<'
                    | '>'
                    | '{'
                    | '}'
                    | '['
                    | ']'
                    | '"'
                    | '\''
                    | ':'
                    | ';'
                    | '!'
                    | '@'
                    | '#'
                    | '$'
                    | '%'
                    | '^'
                    | '&'
                    | '*'
                    | '('
                    | ')'
                    | '-'
                    | '+'
                    | '='
                    | '~'
                    | '|'
                    | '/'
            );
        if needs_escape {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
/// Escapes query-syntax punctuation in `value` for use in a full-text clause.
///
/// Whitespace, `,` and `.` are deliberately left untouched, so multi-word input stays
/// multi-word. Every character in `NEEDS_ESCAPE` gets a preceding backslash.
pub fn escape_text_value(value: &str) -> String {
    const NEEDS_ESCAPE: &[char] = &[
        '\\', '<', '>', '{', '}', '[', ']', '"', '\'', ':', ';', '!', '@', '#', '$', '%', '^',
        '&', '*', '(', ')', '-', '+', '=', '~', '|', '/',
    ];

    let mut out = String::with_capacity(value.len());
    value.chars().for_each(|c| {
        if NEEDS_ESCAPE.contains(&c) {
            out.push('\\');
        }
        out.push(c);
    });
    out
}
/// Lossless conversion. An unsuffixed integer literal defaults to `i32`, so this lets calls
/// like `Filter::gt("year", 2020)` compile without a type suffix.
impl From<i32> for RedisValue {
    fn from(value: i32) -> Self {
        Self::Number(f64::from(value))
    }
}

/// Lossless conversion from `u32`.
impl From<u32> for RedisValue {
    fn from(value: u32) -> Self {
        Self::Number(f64::from(value))
    }
}

/// Converts with `as f64`.
///
/// Magnitudes above 2^53 may be rounded to the nearest representable `f64`.
impl From<i64> for RedisValue {
    fn from(value: i64) -> Self {
        Self::Number(value as f64)
    }
}

/// Converts with `as f64`.
///
/// Values above 2^53 may be rounded to the nearest representable `f64`.
impl From<u64> for RedisValue {
    fn from(value: u64) -> Self {
        Self::Number(value as f64)
    }
}

/// Wraps the float unchanged; finiteness is checked later by the filter constructors.
impl From<f64> for RedisValue {
    fn from(value: f64) -> Self {
        Self::Number(value)
    }
}

impl From<bool> for RedisValue {
    fn from(value: bool) -> Self {
        Self::Bool(value)
    }
}

impl From<String> for RedisValue {
    fn from(value: String) -> Self {
        Self::String(value)
    }
}

impl From<&str> for RedisValue {
    fn from(value: &str) -> Self {
        Self::String(value.to_owned())
    }
}
/// Converts a JSON scalar into a [`RedisValue`].
///
/// Booleans, numbers and strings map to the matching variant.
///
/// # Errors
///
/// * `null`, arrays and objects are rejected with [`FilterError::TypeError`].
/// * A number that `as_f64` cannot represent is rejected with [`FilterError::Expected`],
///   which now reports the offending number.
impl TryFrom<serde_json::Value> for RedisValue {
    type Error = FilterError;

    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
        match value {
            serde_json::Value::Bool(b) => Ok(RedisValue::Bool(b)),
            serde_json::Value::Number(n) => {
                n.as_f64()
                    .map(RedisValue::Number)
                    .ok_or_else(|| FilterError::Expected {
                        expected: "Valid 64-bit float".into(),
                        // Report what was actually received instead of a fixed placeholder.
                        got: n.to_string(),
                    })
            }
            serde_json::Value::String(s) => Ok(RedisValue::String(s)),
            serde_json::Value::Null
            | serde_json::Value::Array(_)
            | serde_json::Value::Object(_) => Err(FilterError::TypeError(
                "Redis filter does not currently support null values, arrays or objects".into(),
            )),
        }
    }
}
/// A rendered Redis query-string filter.
///
/// The wrapped `String` is the query text. Build it with the constructors on `Filter`
/// (or `Filter::raw`) and read it back with `Filter::into_inner`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Filter(String);
impl SearchFilter for Filter {
type Value = RedisNumber;
fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
numeric_eq_filter(key, value)
}
fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
gt_number_filter(key, value)
}
fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
lt_number_filter(key, value)
}
fn and(self, rhs: Self) -> Self {
self.and(rhs)
}
fn or(self, rhs: Self) -> Self {
self.or(rhs)
}
}
impl Filter {
    /// Builds an equality clause on `key`.
    ///
    /// The rendering depends on the value's variant:
    /// * [`RedisValue::Number`] becomes the closed numeric range `@key:[v v]`.
    /// * [`RedisValue::String`] becomes the tag match `@key:{value}`, with the value escaped
    ///   by [`escape_tag_value`].
    /// * [`RedisValue::Bool`] becomes the tag `@key:{1}` or `@key:{0}`.
    ///
    /// `key` is escaped exactly once, segment by segment, with `.` kept as a separator.
    ///
    /// # Errors
    ///
    /// Returns [`FilterError::Expected`] if a numeric value is NaN or infinite.
    pub fn eq(key: impl AsRef<str>, value: impl Into<RedisValue>) -> Result<Self, FilterError> {
        let filter = match value.into() {
            // Hand the raw key to the numeric helper, which escapes it itself. Escaping it
            // here as well turned e.g. `a-b` into `a\\\-b` instead of `a\-b`.
            RedisValue::Number(n) => numeric_eq_filter(key, RedisNumber::new(n)?),
            RedisValue::String(s) => {
                Self(format!("@{}:{{{}}}", field_name(key), escape_tag_value(&s)))
            }
            RedisValue::Bool(b) => {
                // Booleans are matched as the tags `1` / `0`.
                let tag = if b { "1" } else { "0" };
                Self(format!("@{}:{{{tag}}}", field_name(key)))
            }
        };
        Ok(filter)
    }

    /// Strictly-greater-than clause on a numeric field: `@key:[(v +inf]`.
    ///
    /// # Errors
    ///
    /// Returns [`FilterError::Expected`] if `value` is not a finite number.
    pub fn gt(key: impl AsRef<str>, value: impl Into<RedisValue>) -> Result<Self, FilterError> {
        let value = numeric_bound(value.into(), "greater-than")?;
        Ok(gt_number_filter(key, value))
    }

    /// Strictly-less-than clause on a numeric field: `@key:[-inf (v]`.
    ///
    /// # Errors
    ///
    /// Returns [`FilterError::Expected`] if `value` is not a finite number.
    pub fn lt(key: impl AsRef<str>, value: impl Into<RedisValue>) -> Result<Self, FilterError> {
        let value = numeric_bound(value.into(), "less-than")?;
        Ok(lt_number_filter(key, value))
    }

    /// Negates this filter by prefixing its text with `-`.
    ///
    /// The text is not wrapped in parentheses. Filters built by `and` / `or` already are,
    /// but a multi-term `raw` query should be parenthesized by the caller before negating.
    // Clippy suggests implementing `std::ops::Not`; this stays an inherent builder method
    // next to `and` / `or`.
    #[allow(clippy::should_implement_trait)]
    pub fn not(self) -> Self {
        Self(format!("-{}", self.0))
    }

    /// Greater-than-or-equal clause on a numeric field: `@key:[v +inf]`.
    ///
    /// # Errors
    ///
    /// Returns [`FilterError::Expected`] if `value` is not a finite number.
    pub fn gte(key: impl AsRef<str>, value: impl Into<RedisValue>) -> Result<Self, FilterError> {
        let value = numeric_bound(value.into(), "greater-than-or-equal")?;
        Ok(gte_number_filter(key, value))
    }

    /// Less-than-or-equal clause on a numeric field: `@key:[-inf v]`.
    ///
    /// # Errors
    ///
    /// Returns [`FilterError::Expected`] if `value` is not a finite number.
    pub fn lte(key: impl AsRef<str>, value: impl Into<RedisValue>) -> Result<Self, FilterError> {
        let value = numeric_bound(value.into(), "less-than-or-equal")?;
        Ok(lte_number_filter(key, value))
    }

    /// Conjunction, rendered as `(self rhs)`.
    pub fn and(self, rhs: Self) -> Self {
        Self(format!("({} {})", self.0, rhs.0))
    }

    /// Disjunction, rendered as `(self | rhs)`.
    pub fn or(self, rhs: Self) -> Self {
        Self(format!("({} | {})", self.0, rhs.0))
    }

    /// Inclusive numeric range `@key:[min max]`.
    ///
    /// No check is made that `min <= max`.
    ///
    /// # Errors
    ///
    /// Returns [`FilterError::Expected`] if `min` or `max` is NaN or infinite.
    pub fn range(key: impl AsRef<str>, min: f64, max: f64) -> Result<Self, FilterError> {
        let min = RedisNumber::new(min)?.get();
        let max = RedisNumber::new(max)?.get();
        Ok(Self(format!("@{}:[{} {}]", field_name(key), min, max)))
    }

    /// Exclusive numeric range `@key:[(min (max]`.
    ///
    /// No check is made that `min < max`.
    ///
    /// # Errors
    ///
    /// Returns [`FilterError::Expected`] if `min` or `max` is NaN or infinite.
    pub fn range_exclusive(key: impl AsRef<str>, min: f64, max: f64) -> Result<Self, FilterError> {
        let min = RedisNumber::new(min)?.get();
        let max = RedisNumber::new(max)?.get();
        Ok(Self(format!("@{}:[({} ({}]", field_name(key), min, max)))
    }

    /// Tag membership: `@key:{v1 | v2 | ...}`, with each value escaped by
    /// [`escape_tag_value`].
    ///
    /// An empty `values` list returns the wildcard query `*`, which RediSearch treats as
    /// match-all, rather than a clause that matches nothing. Guard against empty allow-lists
    /// at the call site if that distinction matters.
    pub fn tag_in(key: impl AsRef<str>, values: Vec<String>) -> Self {
        if values.is_empty() {
            return Self::raw("*");
        }
        let tags = values
            .iter()
            .map(|value| escape_tag_value(value))
            .collect::<Vec<_>>()
            .join(" | ");
        Self(format!("@{}:{{{}}}", field_name(key), tags))
    }

    /// Full-text clause `@key:(text)`, with punctuation escaped by [`escape_text_value`].
    ///
    /// Whitespace is not escaped, so multi-word input reaches the query parser as separate
    /// terms inside the parentheses.
    pub fn text_contains(key: impl AsRef<str>, text: impl AsRef<str>) -> Self {
        Self(format!(
            "@{}:({})",
            field_name(key),
            escape_text_value(text.as_ref())
        ))
    }

    /// Exact-phrase clause `@key:"phrase"`, with punctuation (including `"`) escaped by
    /// [`escape_text_value`].
    pub fn text_phrase(key: impl AsRef<str>, phrase: impl AsRef<str>) -> Self {
        Self(format!(
            "@{}:\"{}\"",
            field_name(key),
            escape_text_value(phrase.as_ref())
        ))
    }

    /// Wraps `query` verbatim, with no escaping.
    ///
    /// Do not pass untrusted input here.
    pub fn raw(query: impl Into<String>) -> Self {
        Self(query.into())
    }

    /// Returns the rendered query text.
    pub fn into_inner(self) -> String {
        self.0
    }
}
/// Converts a backend-agnostic `rig_core` filter tree into a Redis query filter.
///
/// Leaf values are converted with `RedisValue::try_from`, so `null`, arrays and objects are
/// rejected. `Gt` / `Lt` additionally require numeric values. `And` / `Or` nodes are
/// converted recursively, left operand first.
impl TryFrom<CoreFilter<serde_json::Value>> for Filter {
    type Error = FilterError;

    fn try_from(value: CoreFilter<serde_json::Value>) -> Result<Self, Self::Error> {
        match value {
            CoreFilter::Eq(key, json) => Filter::eq(key, RedisValue::try_from(json)?),
            CoreFilter::Gt(key, json) => Filter::gt(key, RedisValue::try_from(json)?),
            CoreFilter::Lt(key, json) => Filter::lt(key, RedisValue::try_from(json)?),
            CoreFilter::And(lhs, rhs) => {
                let lhs = Self::try_from(*lhs)?;
                let rhs = Self::try_from(*rhs)?;
                Ok(lhs.and(rhs))
            }
            CoreFilter::Or(lhs, rhs) => {
                let lhs = Self::try_from(*lhs)?;
                let rhs = Self::try_from(*rhs)?;
                Ok(lhs.or(rhs))
            }
        }
    }
}