use serde::{Deserialize, Serialize};
/// Deserializes the `fields` value, accepting either a string or an array of strings.
///
/// * `null` yields `Ok(None)`.
/// * A string is returned unchanged.
/// * An array of strings is joined with `,` into a single string.
///
/// # Errors
///
/// Returns a custom deserialization error if an array element is not a string, or if the
/// value is any other JSON type.
fn deserialize_fields<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::de::Error;
    use serde_json::Value;

    match Option::<Value>::deserialize(deserializer)? {
        None => Ok(None),
        Some(Value::String(text)) => Ok(Some(text)),
        Some(Value::Array(items)) => {
            let mut names: Vec<String> = Vec::with_capacity(items.len());
            for item in items {
                match item {
                    // Move the owned string out of the JSON value.
                    Value::String(name) => names.push(name),
                    _ => return Err(D::Error::custom("Array elements must be strings")),
                }
            }
            Ok(Some(names.join(",")))
        }
        Some(_) => Err(D::Error::custom(
            "fields must be a string or array of strings",
        )),
    }
}
/// Deserializes the `weights` value, accepting either a string or an object of numbers.
///
/// * `null` yields `Ok(None)`.
/// * A string is returned unchanged.
/// * An object is rendered as comma-separated `key:weight` pairs, in the iteration order of
///   the JSON map.
///
/// # Errors
///
/// Returns a custom deserialization error if the object is empty, if any value in the object
/// is not a number, or if the value is any other JSON type.
fn deserialize_weights<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::de::Error;
    use serde_json::Value;

    match Option::<Value>::deserialize(deserializer)? {
        None => Ok(None),
        Some(Value::String(s)) => Ok(Some(s)),
        Some(Value::Object(map)) => {
            if map.is_empty() {
                return Err(D::Error::custom(
                    "weights object must contain numeric values",
                ));
            }
            let mut pairs: Vec<String> = Vec::with_capacity(map.len());
            for (key, val) in map {
                match val.as_f64() {
                    Some(weight) => pairs.push(format!("{}:{}", key, weight)),
                    // Reject the entry instead of dropping it: a silently ignored weight
                    // would change ranking without telling the caller.
                    None => {
                        return Err(D::Error::custom(format!(
                            "weight for field `{}` must be a number",
                            key
                        )))
                    }
                }
            }
            Ok(Some(pairs.join(",")))
        }
        Some(_) => Err(D::Error::custom(
            "weights must be a string or object with numeric values",
        )),
    }
}
/// Search request payload.
///
/// `query` is the only required field. Every optional field is left out of the serialized
/// output when it is `None` and is read as `None` when absent from the input.
///
/// NOTE(review): the per-field descriptions below are inferred from field names and types;
/// the authoritative semantics are defined by the service that consumes this payload.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SearchQuery {
    /// Query text.
    pub query: String,
    /// Language setting for the search.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    /// Case-sensitivity flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub case_sensitive: Option<bool>,
    /// Fuzzy-matching flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fuzzy: Option<bool>,
    /// Minimum score setting.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_score: Option<f64>,
    /// Field list stored as a single string.
    ///
    /// On input this accepts either a string or an array of strings; an array is joined
    /// with `,`.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        deserialize_with = "deserialize_fields"
    )]
    pub fields: Option<String>,
    /// Field weights stored as a single string.
    ///
    /// On input this accepts either a string or an object with numeric values; an object is
    /// rendered as comma-separated `key:weight` pairs.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        deserialize_with = "deserialize_weights"
    )]
    pub weights: Option<String>,
    /// Stemming flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enable_stemming: Option<bool>,
    /// Exact-match boost flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub boost_exact: Option<bool>,
    /// Maximum edit distance setting.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_edit_distance: Option<u32>,
    /// Ripple bypass flag.
    // Omitted when `None`, like every other optional field, instead of emitting `null`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub bypass_ripple: Option<bool>,
    /// Cache bypass flag.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub bypass_cache: Option<bool>,
    /// Result limit.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub limit: Option<usize>,
    /// Query vector.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vector: Option<Vec<f64>>,
    /// Name of the vector field.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vector_field: Option<String>,
    /// Vector metric name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vector_metric: Option<String>,
    /// Vector `k` setting.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vector_k: Option<usize>,
    /// Vector threshold setting.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vector_threshold: Option<f64>,
    /// Text weight setting.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text_weight: Option<f64>,
    /// Vector weight setting.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vector_weight: Option<f64>,
    /// Fields to select.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub select_fields: Option<Vec<String>>,
    /// Fields to exclude.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude_fields: Option<Vec<String>>,
}
/// Consuming builder methods for [`SearchQuery`].
///
/// Each setter takes the query by value, stores `Some(value)` in the field of the same name,
/// and returns the updated query so calls can be chained.
impl SearchQuery {
    /// Creates a query with the given `query` text and every optional field unset.
    pub fn new(query: impl Into<String>) -> Self {
        Self {
            query: query.into(),
            ..Self::default()
        }
    }

    /// Sets `language`.
    pub fn language(self, language: impl Into<String>) -> Self {
        Self {
            language: Some(language.into()),
            ..self
        }
    }

    /// Sets `case_sensitive`.
    pub fn case_sensitive(self, enabled: bool) -> Self {
        Self {
            case_sensitive: Some(enabled),
            ..self
        }
    }

    /// Sets `fuzzy`.
    pub fn fuzzy(self, enabled: bool) -> Self {
        Self {
            fuzzy: Some(enabled),
            ..self
        }
    }

    /// Sets `min_score`.
    pub fn min_score(self, score: f64) -> Self {
        Self {
            min_score: Some(score),
            ..self
        }
    }

    /// Sets `fields` to the given string, stored as-is.
    pub fn fields(self, fields: impl Into<String>) -> Self {
        Self {
            fields: Some(fields.into()),
            ..self
        }
    }

    /// Sets `weights` to the given string, stored as-is.
    pub fn weights(self, weights: impl Into<String>) -> Self {
        Self {
            weights: Some(weights.into()),
            ..self
        }
    }

    /// Sets `enable_stemming`.
    pub fn enable_stemming(self, enabled: bool) -> Self {
        Self {
            enable_stemming: Some(enabled),
            ..self
        }
    }

    /// Sets `boost_exact`.
    pub fn boost_exact(self, enabled: bool) -> Self {
        Self {
            boost_exact: Some(enabled),
            ..self
        }
    }

    /// Sets `max_edit_distance`.
    pub fn max_edit_distance(self, distance: u32) -> Self {
        Self {
            max_edit_distance: Some(distance),
            ..self
        }
    }

    /// Sets `vector`.
    pub fn vector(self, vector: Vec<f64>) -> Self {
        Self {
            vector: Some(vector),
            ..self
        }
    }

    /// Sets `vector_field`.
    pub fn vector_field(self, field: impl Into<String>) -> Self {
        Self {
            vector_field: Some(field.into()),
            ..self
        }
    }

    /// Sets `vector_metric`.
    pub fn vector_metric(self, metric: impl Into<String>) -> Self {
        Self {
            vector_metric: Some(metric.into()),
            ..self
        }
    }

    /// Sets `vector_k`.
    pub fn vector_k(self, k: usize) -> Self {
        Self {
            vector_k: Some(k),
            ..self
        }
    }

    /// Sets `vector_threshold`.
    pub fn vector_threshold(self, threshold: f64) -> Self {
        Self {
            vector_threshold: Some(threshold),
            ..self
        }
    }

    /// Sets `text_weight`.
    pub fn text_weight(self, weight: f64) -> Self {
        Self {
            text_weight: Some(weight),
            ..self
        }
    }

    /// Sets `vector_weight`.
    pub fn vector_weight(self, weight: f64) -> Self {
        Self {
            vector_weight: Some(weight),
            ..self
        }
    }

    /// Sets `bypass_ripple`.
    pub fn bypass_ripple(self, bypass: bool) -> Self {
        Self {
            bypass_ripple: Some(bypass),
            ..self
        }
    }

    /// Sets `bypass_cache`.
    pub fn bypass_cache(self, bypass: bool) -> Self {
        Self {
            bypass_cache: Some(bypass),
            ..self
        }
    }

    /// Sets `limit`.
    pub fn limit(self, limit: usize) -> Self {
        Self {
            limit: Some(limit),
            ..self
        }
    }

    /// Sets `select_fields`.
    pub fn select_fields(self, fields: Vec<String>) -> Self {
        Self {
            select_fields: Some(fields),
            ..self
        }
    }

    /// Sets `exclude_fields`.
    pub fn exclude_fields(self, fields: Vec<String>) -> Self {
        Self {
            exclude_fields: Some(fields),
            ..self
        }
    }
}
/// One entry of a [`SearchResponse`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SearchResult {
    /// The record, kept as an untyped JSON value.
    pub record: serde_json::Value,
    /// Score attached to this record.
    pub score: f64,
    /// Presumably the names of the fields that matched; confirm against the service docs.
    pub matched_fields: Vec<String>,
}
/// Search response payload.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SearchResponse {
    /// Result entries.
    pub results: Vec<SearchResult>,
    /// Total count as reported in the payload.
    pub total: usize,
    /// Execution time in milliseconds, per the field name.
    pub execution_time_ms: u64,
}
/// Distinct-values request payload.
///
/// Every field is optional and is left out of the serialized output when it is `None`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct DistinctValuesQuery {
    /// Filter expression, kept as an untyped JSON value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<serde_json::Value>,
    /// Ripple bypass flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bypass_ripple: Option<bool>,
    /// Cache bypass flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bypass_cache: Option<bool>,
}
/// Consuming builder methods for [`DistinctValuesQuery`].
impl DistinctValuesQuery {
    /// Creates a query with every field unset.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets `filter` and returns the updated query.
    pub fn filter(self, filter: serde_json::Value) -> Self {
        Self {
            filter: Some(filter),
            ..self
        }
    }

    /// Sets `bypass_ripple` and returns the updated query.
    pub fn bypass_ripple(self, bypass: bool) -> Self {
        Self {
            bypass_ripple: Some(bypass),
            ..self
        }
    }

    /// Sets `bypass_cache` and returns the updated query.
    pub fn bypass_cache(self, bypass: bool) -> Self {
        Self {
            bypass_cache: Some(bypass),
            ..self
        }
    }
}
/// Distinct-values response payload.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DistinctValuesResponse {
    /// Collection name.
    pub collection: String,
    /// Field name.
    pub field: String,
    /// Values, kept as untyped JSON.
    pub values: Vec<serde_json::Value>,
    /// Count as reported in the payload.
    pub count: usize,
}
/// Unit tests for the query builders and the (de)serialization of the payload types.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_search_query_builder() {
        let query = SearchQuery::new("test query")
            .language("english")
            .fuzzy(true)
            .min_score(0.5);
        assert_eq!(query.query, "test query");
        assert_eq!(query.language, Some("english".to_string()));
        assert_eq!(query.fuzzy, Some(true));
        assert_eq!(query.min_score, Some(0.5));
    }

    #[test]
    fn test_vector_search_params() {
        let query = SearchQuery::new("test")
            .vector(vec![0.1, 0.2, 0.3])
            .vector_field("embedding")
            .vector_metric("cosine")
            .vector_k(5)
            .vector_threshold(0.8);
        assert_eq!(query.vector, Some(vec![0.1, 0.2, 0.3]));
        assert_eq!(query.vector_field, Some("embedding".to_string()));
        assert_eq!(query.vector_metric, Some("cosine".to_string()));
        assert_eq!(query.vector_k, Some(5));
        assert_eq!(query.vector_threshold, Some(0.8));
    }

    #[test]
    fn test_hybrid_search_weights() {
        let query = SearchQuery::new("test").text_weight(0.7).vector_weight(0.3);
        assert_eq!(query.text_weight, Some(0.7));
        assert_eq!(query.vector_weight, Some(0.3));
    }

    #[test]
    fn test_deserialize_fields_from_string() {
        let json = json!({
            "query": "test",
            "fields": "name,email,description"
        });
        let query: SearchQuery = serde_json::from_value(json).unwrap();
        assert_eq!(query.fields, Some("name,email,description".to_string()));
    }

    #[test]
    fn test_deserialize_fields_from_array() {
        let json = json!({
            "query": "test",
            "fields": ["name", "email", "description"]
        });
        let query: SearchQuery = serde_json::from_value(json).unwrap();
        assert_eq!(query.fields, Some("name,email,description".to_string()));
    }

    // Both custom deserializers must map an absent key and an explicit `null` to `None`.
    #[test]
    fn test_deserialize_fields_and_weights_absent_or_null() {
        let absent: SearchQuery = serde_json::from_value(json!({ "query": "test" })).unwrap();
        assert_eq!(absent.fields, None);
        assert_eq!(absent.weights, None);

        let nulls: SearchQuery = serde_json::from_value(json!({
            "query": "test",
            "fields": null,
            "weights": null
        }))
        .unwrap();
        assert_eq!(nulls.fields, None);
        assert_eq!(nulls.weights, None);
    }

    #[test]
    fn test_deserialize_fields_rejects_invalid_input() {
        let non_string_element = json!({ "query": "test", "fields": ["name", 1] });
        assert!(serde_json::from_value::<SearchQuery>(non_string_element).is_err());

        let wrong_type = json!({ "query": "test", "fields": 42 });
        assert!(serde_json::from_value::<SearchQuery>(wrong_type).is_err());
    }

    #[test]
    fn test_deserialize_weights_from_string() {
        let json = json!({
            "query": "test",
            "weights": "name:2.0,email:1.5"
        });
        let query: SearchQuery = serde_json::from_value(json).unwrap();
        assert_eq!(query.weights, Some("name:2.0,email:1.5".to_string()));
    }

    #[test]
    fn test_deserialize_weights_from_object() {
        let json = json!({
            "query": "test",
            "weights": {
                "name": 2.0,
                "email": 1.5,
                "description": 1.0
            }
        });
        let query: SearchQuery = serde_json::from_value(json).unwrap();
        assert!(query.weights.is_some());
        let weights = query.weights.unwrap();
        assert!(weights.contains("name:2"));
        assert!(weights.contains("email:1.5"));
        assert!(weights.contains("description:1"));
    }

    #[test]
    fn test_deserialize_weights_rejects_invalid_input() {
        let no_numeric_values = json!({ "query": "test", "weights": { "name": "high" } });
        assert!(serde_json::from_value::<SearchQuery>(no_numeric_values).is_err());

        let empty_object = json!({ "query": "test", "weights": {} });
        assert!(serde_json::from_value::<SearchQuery>(empty_object).is_err());

        let wrong_type = json!({ "query": "test", "weights": [1.0, 2.0] });
        assert!(serde_json::from_value::<SearchQuery>(wrong_type).is_err());
    }

    #[test]
    fn test_search_query_serialization() {
        let query = SearchQuery::new("test query")
            .language("english")
            .fuzzy(true)
            .fields("name,email");
        let json = serde_json::to_value(&query).unwrap();
        assert_eq!(json["query"], "test query");
        assert_eq!(json["language"], "english");
        assert_eq!(json["fuzzy"], true);
        assert_eq!(json["fields"], "name,email");
    }

    #[test]
    fn test_bypass_flags() {
        let query = SearchQuery::new("test")
            .bypass_cache(true)
            .bypass_ripple(true);
        assert_eq!(query.bypass_cache, Some(true));
        assert_eq!(query.bypass_ripple, Some(true));
    }

    #[test]
    fn test_max_edit_distance() {
        let query = SearchQuery::new("test").max_edit_distance(2);
        assert_eq!(query.max_edit_distance, Some(2));
    }

    #[test]
    fn test_enable_stemming_and_boost_exact() {
        let query = SearchQuery::new("test")
            .enable_stemming(true)
            .boost_exact(true);
        assert_eq!(query.enable_stemming, Some(true));
        assert_eq!(query.boost_exact, Some(true));
    }

    #[test]
    fn test_limit() {
        let query = SearchQuery::new("test").limit(10);
        assert_eq!(query.limit, Some(10));
    }

    #[test]
    fn test_limit_serialization() {
        let query = SearchQuery::new("test").limit(5);
        let json = serde_json::to_value(&query).unwrap();
        assert_eq!(json["limit"], 5);
    }

    #[test]
    fn test_distinct_values_query_default_is_empty() {
        let q = DistinctValuesQuery::new();
        assert!(q.filter.is_none());
        assert!(q.bypass_ripple.is_none());
        assert!(q.bypass_cache.is_none());
    }

    #[test]
    fn test_distinct_values_query_builder() {
        let filter = json!({
            "type": "Condition",
            "content": {"field": "status", "operator": "Eq", "value": "active"}
        });
        let q = DistinctValuesQuery::new()
            .filter(filter.clone())
            .bypass_ripple(true)
            .bypass_cache(false);
        assert_eq!(q.filter.as_ref().unwrap(), &filter);
        assert_eq!(q.bypass_ripple, Some(true));
        assert_eq!(q.bypass_cache, Some(false));
    }

    #[test]
    fn test_distinct_values_query_skips_none_fields() {
        let q = DistinctValuesQuery::new();
        let json = serde_json::to_value(&q).unwrap();
        assert!(json.get("filter").is_none());
        assert!(json.get("bypass_ripple").is_none());
        assert!(json.get("bypass_cache").is_none());
    }

    #[test]
    fn test_distinct_values_response_deserialization() {
        let raw = json!({
            "collection": "products",
            "field": "category",
            "values": ["books", "electronics", "food"],
            "count": 3
        });
        let resp: DistinctValuesResponse = serde_json::from_value(raw).unwrap();
        assert_eq!(resp.collection, "products");
        assert_eq!(resp.field, "category");
        assert_eq!(resp.count, 3);
        assert_eq!(resp.values.len(), 3);
        assert_eq!(resp.values[0].as_str(), Some("books"));
    }
}