use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::lexical::core::field::FieldValue;
use crate::lexical::query::Query;
use crate::lexical::query::QueryResult;
use crate::lexical::reader::LexicalIndexReader;
use crate::lexical::search::features::facet::{FacetCollector, FacetResults};
use crate::lexical::search::features::highlight::{HighlightConfig, Highlighter};
/// Tuning knobs for post-query result processing: paging, stored-field
/// retrieval, highlighting, faceting, snippets, and grouping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResultProcessorConfig {
/// Maximum number of hits to keep per request.
pub max_results: usize,
/// Whether to produce per-field highlight fragments.
pub enable_highlighting: bool,
/// Whether to collect facet counts over the result set.
pub enable_faceting: bool,
/// Whether to produce truncated text snippets per field.
pub enable_snippets: bool,
/// Whether to load stored document fields for each hit.
pub retrieve_fields: bool,
/// Field names to retrieve; `"*"` means all fields.
pub fields_to_retrieve: Vec<String>,
/// Field names eligible for highlighting.
pub fields_to_highlight: Vec<String>,
/// Field names to facet on (used when `enable_faceting` is set).
pub facet_fields: Vec<String>,
/// Snippet truncation limit (in characters, per `generate_snippets`).
pub snippet_length: usize,
/// Whether to group hits by `group_by_field`.
pub enable_grouping: bool,
/// Field whose value becomes the group key; `None` disables grouping.
pub group_by_field: Option<String>,
}
impl Default for ResultProcessorConfig {
fn default() -> Self {
ResultProcessorConfig {
max_results: 10,
enable_highlighting: false,
enable_faceting: false,
enable_snippets: false,
retrieve_fields: true,
fields_to_retrieve: vec!["*".to_string()],
fields_to_highlight: Vec::new(),
facet_fields: Vec::new(),
snippet_length: 200,
enable_grouping: false,
group_by_field: None,
}
}
}
/// Fully processed response for one search request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessedSearchResults {
/// Processed hits, at most `max_results` of them.
pub hits: Vec<ProcessedHit>,
/// Total number of matching documents.
pub total_hits: u64,
/// Highest score among the returned hits (0.0 when empty).
pub max_score: f32,
/// Facet counts (empty unless faceting is enabled).
pub facets: FacetResults,
/// Named aggregations computed over the result set.
pub aggregations: HashMap<String, AggregationResult>,
/// Query suggestions (currently always empty — see `process_results`).
pub suggestions: Vec<String>,
/// Grouped hits when grouping is enabled, `None` otherwise.
pub groups: Option<Vec<ResultGroup>>,
/// Wall-clock processing time in milliseconds.
pub processing_time_ms: u64,
}
/// One search hit after field retrieval, highlighting, and snippeting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessedHit {
/// Internal document identifier.
pub doc_id: u64,
/// Relevance score assigned by the query.
pub score: f32,
/// Retrieved stored fields, stringified (empty if retrieval disabled).
pub fields: HashMap<String, String>,
/// Highlight fragments per field (empty if highlighting disabled).
pub highlights: HashMap<String, Vec<String>>,
/// Truncated snippet per field (empty if snippets disabled).
pub snippets: HashMap<String, String>,
/// Score explanation tree; currently never populated.
pub explanation: Option<ScoreExplanation>,
}
/// A bucket of hits sharing the same `group_by_field` value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResultGroup {
/// The shared field value ("[no value]" when the field was absent).
pub group_key: String,
/// Hits belonging to this group.
pub hits: Vec<ProcessedHit>,
/// Number of hits in this group.
pub total_hits: u64,
/// Highest score within the group (0.0 when empty).
pub max_score: f32,
}
/// Recursive breakdown of how a hit's score was computed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreExplanation {
/// Score contribution of this node.
pub value: f32,
/// Human-readable description of this contribution.
pub description: String,
/// Sub-explanations that combine into `value`.
pub details: Vec<ScoreExplanation>,
}
/// Result of a single aggregation, serialized with an external `"type"`
/// tag so consumers can dispatch on the variant name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum AggregationResult {
/// Plain document count.
Count { count: u64 },
/// Sum of a numeric field/metric.
Sum { sum: f64 },
/// Mean over `count` values.
Average { avg: f64, count: u64 },
/// Minimum and maximum observed values.
MinMax { min: f64, max: f64 },
/// Term buckets (one per distinct value).
Terms { buckets: Vec<TermsBucket> },
/// Numeric range buckets.
Range { buckets: Vec<RangeBucket> },
/// Time-bucketed histogram.
DateHistogram { buckets: Vec<DateHistogramBucket> },
}
/// One bucket of a terms aggregation: a distinct value and its count.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TermsBucket {
pub key: String,
pub doc_count: u64,
}
/// One bucket of a range aggregation; `from`/`to` are `None` for
/// open-ended (unbounded) edges.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RangeBucket {
pub key: String,
pub from: Option<f64>,
pub to: Option<f64>,
pub doc_count: u64,
}
/// One bucket of a date histogram: numeric key (presumably an epoch
/// timestamp — confirm with the producer) plus its rendered form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DateHistogramBucket {
pub key: u64,
pub key_as_string: String,
pub doc_count: u64,
}
/// Turns raw `QueryResult`s into `ProcessedSearchResults` according to a
/// `ResultProcessorConfig`, reading stored documents through `reader`.
#[derive(Debug)]
pub struct ResultProcessor {
// Behavior configuration supplied at construction.
config: ResultProcessorConfig,
// Present only when `config.enable_highlighting` was set in `new`.
highlighter: Option<Highlighter>,
// Present only when `config.enable_faceting` was set in `new`.
facet_collector: Option<FacetCollector>,
// Shared handle to the index for document/field access.
reader: Arc<dyn LexicalIndexReader>,
}
impl ResultProcessor {
/// Builds a processor for `config`, constructing the optional
/// highlighter and facet collector only when their features are enabled.
///
/// # Errors
/// Currently always returns `Ok`; the `Result` is kept for forward
/// compatibility with fallible sub-component construction.
pub fn new(config: ResultProcessorConfig, reader: Arc<dyn LexicalIndexReader>) -> Result<Self> {
    let mut highlighter = None;
    if config.enable_highlighting {
        highlighter = Some(Highlighter::new(HighlightConfig::default()));
    }
    let mut facet_collector = None;
    if config.enable_faceting {
        let facet_config = crate::lexical::search::features::facet::FacetConfig::default();
        facet_collector = Some(FacetCollector::new(facet_config, config.facet_fields.clone()));
    }
    Ok(ResultProcessor {
        config,
        highlighter,
        facet_collector,
        reader,
    })
}
pub fn process_results<Q: Query>(
&mut self,
raw_results: Vec<QueryResult>,
query: &Q,
) -> Result<ProcessedSearchResults> {
let start_time = crate::util::time::Timer::now();
let limited_results: Vec<_> = raw_results
.into_iter()
.take(self.config.max_results)
.collect();
let total_hits = limited_results.len() as u64;
let max_score = limited_results
.iter()
.map(|r| r.score)
.fold(0.0f32, f32::max);
let mut processed_hits = Vec::new();
for result in &limited_results {
let processed_hit = self.process_hit(result, query)?;
processed_hits.push(processed_hit);
}
let facets = if let Some(ref mut collector) = self.facet_collector {
for result in &limited_results {
collector.collect_doc(result.doc_id, self.reader.as_ref())?;
}
FacetResults::empty() } else {
FacetResults::empty()
};
let aggregations = self.collect_aggregations(&limited_results)?;
let groups = if self.config.enable_grouping {
Some(self.group_results(&processed_hits)?)
} else {
None
};
let suggestions = Vec::new();
let processing_time_ms = start_time.elapsed_ms();
Ok(ProcessedSearchResults {
hits: processed_hits,
total_hits,
max_score,
facets,
aggregations,
suggestions,
groups,
processing_time_ms,
})
}
/// Decorates a single raw result with stored fields, highlights, and
/// snippets, each gated by the corresponding config flag.
fn process_hit<Q: Query>(&self, result: &QueryResult, query: &Q) -> Result<ProcessedHit> {
    let mut fields = HashMap::new();
    let mut highlights = HashMap::new();
    let mut snippets = HashMap::new();
    if self.config.retrieve_fields {
        fields = self.retrieve_document_fields(result.doc_id)?;
    }
    // Highlighting needs both the flag and a constructed highlighter.
    if self.config.enable_highlighting && self.highlighter.is_some() {
        highlights = self.generate_highlights(result.doc_id, &fields, query)?;
    }
    if self.config.enable_snippets {
        snippets = self.generate_snippets(result.doc_id, &fields)?;
    }
    Ok(ProcessedHit {
        doc_id: result.doc_id,
        score: result.score,
        fields,
        highlights,
        snippets,
        // Score explanations are not generated yet.
        explanation: None,
    })
}
/// Loads the stored document for `doc_id` and stringifies every field
/// that passes the `fields_to_retrieve` filter. Missing documents yield
/// an empty map rather than an error.
fn retrieve_document_fields(&self, doc_id: u64) -> Result<HashMap<String, String>> {
    let mut retrieved = HashMap::new();
    if let Some(document) = self.reader.document(doc_id)? {
        for (name, value) in &document.fields {
            if !self.should_retrieve_field(name) {
                continue;
            }
            retrieved.insert(name.clone(), self.field_value_to_string(value));
        }
    }
    Ok(retrieved)
}
/// Returns true when `field_name` should be retrieved: either the
/// wildcard `"*"` is configured or the field is listed explicitly.
///
/// Single pass with no allocation — the original built a fresh
/// `String` for both the wildcard probe and the field-name probe on
/// every call.
fn should_retrieve_field(&self, field_name: &str) -> bool {
    self.config
        .fields_to_retrieve
        .iter()
        .any(|f| f == "*" || f == field_name)
}
/// Renders a stored `FieldValue` as a display string for result
/// payloads. Scalars use their natural textual form; binary, geo, and
/// vector values get compact placeholders instead of raw data.
fn field_value_to_string(&self, field_value: &FieldValue) -> String {
match field_value {
FieldValue::Text(s) => s.clone(),
FieldValue::Int64(i) => i.to_string(),
FieldValue::Float64(f) => f.to_string(),
FieldValue::Bool(b) => b.to_string(),
FieldValue::Bytes(data, mime) => {
// Summarize blobs by MIME type ("bin" when unknown) and size.
let m = mime.as_deref().unwrap_or("bin");
format!("[blob: {} ({} bytes)]", m, data.len())
}
FieldValue::DateTime(dt) => dt.to_rfc3339(),
// "lat,lon" — order matches the variant's (lat, lon) fields.
FieldValue::Geo(lat, lon) => format!("{},{}", lat, lon),
FieldValue::Null => "null".to_string(),
// Vectors are summarized by dimensionality only.
FieldValue::Vector(v) => format!("[vector: dim={}]", v.len()),
}
}
/// Produces highlight fragments for each configured field that is
/// present in `fields`. Fields with no fragments are omitted; an absent
/// highlighter yields an empty map.
fn generate_highlights<Q: Query>(
    &self,
    _doc_id: u64,
    fields: &HashMap<String, String>,
    query: &Q,
) -> Result<HashMap<String, Vec<String>>> {
    let mut highlights = HashMap::new();
    let highlighter = match &self.highlighter {
        Some(h) => h,
        None => return Ok(highlights),
    };
    for field_name in &self.config.fields_to_highlight {
        let field_text = match fields.get(field_name) {
            Some(text) => text,
            None => continue,
        };
        let result = highlighter.highlight(query, field_name, field_text)?;
        let fragments: Vec<String> = result
            .fragments
            .iter()
            .map(|fragment| fragment.text.clone())
            .collect();
        if !fragments.is_empty() {
            highlights.insert(field_name.clone(), fragments);
        }
    }
    Ok(highlights)
}
/// Produces a snippet per field: text longer than `snippet_length`
/// characters is truncated and suffixed with "...", shorter text is
/// passed through unchanged.
///
/// The length check now counts *characters* to match the char-based
/// truncation below. The original compared `str::len()` (bytes) against
/// the character limit, so multi-byte text under the limit in chars but
/// over it in bytes got a spurious "..." appended without being
/// truncated.
fn generate_snippets(
    &self,
    _doc_id: u64,
    fields: &HashMap<String, String>,
) -> Result<HashMap<String, String>> {
    let limit = self.config.snippet_length;
    let mut snippets = HashMap::with_capacity(fields.len());
    for (field_name, field_text) in fields {
        let snippet = if field_text.chars().count() > limit {
            let truncated: String = field_text.chars().take(limit).collect();
            truncated + "..."
        } else {
            field_text.clone()
        };
        snippets.insert(field_name.clone(), snippet);
    }
    Ok(snippets)
}
/// Feeds each result document into `collector` via the static helper.
/// NOTE(review): like `collect_facets_static`, this returns empty
/// results without finalizing the collector — callers must finalize
/// separately (see `collect_facets_finalize`). Currently unused.
#[allow(dead_code)]
fn collect_facets(
&self,
collector: &mut FacetCollector,
results: &[QueryResult],
) -> Result<FacetResults> {
Self::collect_facets_static(collector, results, &self.reader)
}
/// Accumulates facet state for `results` into `collector` but always
/// returns `FacetResults::empty()` — the collected state lives on in
/// `collector` and must be finalized by the caller (compare
/// `collect_facets_finalize`, which does finalize). Currently unused.
#[allow(dead_code)]
fn collect_facets_static(
collector: &mut FacetCollector,
results: &[QueryResult],
reader: &Arc<dyn LexicalIndexReader>,
) -> Result<FacetResults> {
for result in results {
collector.collect_doc(result.doc_id, reader.as_ref())?;
}
Ok(FacetResults::empty()) }
/// Consumes `collector`, feeds it every result document, and finalizes
/// it into the facet counts for this batch. This is the complete
/// collect-then-finalize path (unlike `collect_facets_static`, which
/// never finalizes).
#[allow(dead_code)]
fn collect_facets_finalize(
mut collector: FacetCollector,
results: &[QueryResult],
reader: &Arc<dyn LexicalIndexReader>,
) -> Result<FacetResults> {
for result in results {
collector.collect_doc(result.doc_id, reader.as_ref())?;
}
collector.finalize()
}
/// Builds the standard aggregations over a result batch:
/// `total_count` always, plus `score_stats` (average) and `score_range`
/// (min/max) whenever at least one result is present.
fn collect_aggregations(
    &self,
    results: &[QueryResult],
) -> Result<HashMap<String, AggregationResult>> {
    let mut aggregations = HashMap::new();
    let count = results.len() as u64;
    aggregations.insert("total_count".to_string(), AggregationResult::Count { count });
    if results.is_empty() {
        return Ok(aggregations);
    }
    // Single pass accumulates sum, min, and max together; iteration
    // order (and therefore the float sum) matches the original.
    let mut sum = 0.0f64;
    let mut min = f64::INFINITY;
    let mut max = f64::NEG_INFINITY;
    for result in results {
        let score = result.score as f64;
        sum += score;
        min = min.min(score);
        max = max.max(score);
    }
    let avg = sum / results.len() as f64;
    aggregations.insert(
        "score_stats".to_string(),
        AggregationResult::Average { avg, count },
    );
    aggregations.insert(
        "score_range".to_string(),
        AggregationResult::MinMax { min, max },
    );
    Ok(aggregations)
}
/// Buckets hits by the configured `group_by_field` value; hits without
/// that field fall into a "[no value]" group. Returns an empty list
/// when no group-by field is configured.
///
/// # Errors
/// Currently always returns `Ok`; kept fallible for interface
/// stability.
fn group_results(&self, hits: &[ProcessedHit]) -> Result<Vec<ResultGroup>> {
    let group_field = match &self.config.group_by_field {
        Some(field) => field,
        None => return Ok(Vec::new()),
    };
    // BTreeMap keeps group keys in deterministic sorted order.
    let mut groups: BTreeMap<String, Vec<ProcessedHit>> = BTreeMap::new();
    for hit in hits {
        // `unwrap_or_else` allocates the fallback only on a miss; the
        // original built the "[no value]" String for every hit even
        // when the field was present.
        let group_key = hit
            .fields
            .get(group_field)
            .cloned()
            .unwrap_or_else(|| "[no value]".to_string());
        groups.entry(group_key).or_default().push(hit.clone());
    }
    let result_groups = groups
        .into_iter()
        .map(|(group_key, group_hits)| {
            let total_hits = group_hits.len() as u64;
            let max_score = group_hits.iter().map(|h| h.score).fold(0.0f32, f32::max);
            ResultGroup {
                group_key,
                hits: group_hits,
                total_hits,
                max_score,
            }
        })
        .collect();
    Ok(result_groups)
}
}
/// Fluent builder that assembles a `ResultProcessorConfig`; `build()`
/// yields the config, not a processor.
#[derive(Debug)]
pub struct ResultProcessorBuilder {
config: ResultProcessorConfig,
}
impl ResultProcessorBuilder {
    /// Starts from `ResultProcessorConfig::default()`.
    pub fn new() -> Self {
        Self {
            config: ResultProcessorConfig::default(),
        }
    }

    /// Caps the number of hits returned per request.
    pub fn max_results(mut self, max_results: usize) -> Self {
        self.config.max_results = max_results;
        self
    }

    /// Turns on highlighting for the given fields.
    pub fn enable_highlighting(mut self, fields: Vec<String>) -> Self {
        self.config.fields_to_highlight = fields;
        self.config.enable_highlighting = true;
        self
    }

    /// Turns on faceting over the given fields.
    pub fn enable_faceting(mut self, fields: Vec<String>) -> Self {
        self.config.facet_fields = fields;
        self.config.enable_faceting = true;
        self
    }

    /// Turns on snippets truncated to `length` characters.
    pub fn enable_snippets(mut self, length: usize) -> Self {
        self.config.snippet_length = length;
        self.config.enable_snippets = true;
        self
    }

    /// Restricts stored-field retrieval to the given fields.
    pub fn retrieve_fields(mut self, fields: Vec<String>) -> Self {
        self.config.fields_to_retrieve = fields;
        self.config.retrieve_fields = true;
        self
    }

    /// Turns on grouping by the given field.
    pub fn enable_grouping(mut self, field: String) -> Self {
        self.config.group_by_field = Some(field);
        self.config.enable_grouping = true;
        self
    }

    /// Consumes the builder and returns the assembled configuration.
    pub fn build(self) -> ResultProcessorConfig {
        self.config
    }
}
// Default delegates to `new()` so `Default::default()` and
// `ResultProcessorBuilder::new()` are interchangeable.
impl Default for ResultProcessorBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Config built via struct-update syntax keeps explicit overrides and
    // defaults the rest. (The stray `#[allow(dead_code)]` that used to
    // sit on this `#[test]` fn was redundant — test functions are never
    // dead-code candidates — and has been removed.)
    #[test]
    fn test_result_processor_config() {
        let config = ResultProcessorConfig {
            max_results: 20,
            enable_highlighting: true,
            fields_to_highlight: vec!["title".to_string()],
            ..Default::default()
        };
        assert_eq!(config.max_results, 20);
        assert!(config.enable_highlighting);
        assert_eq!(config.fields_to_highlight, vec!["title"]);
    }

    // Builder methods should both set their flag and record their data.
    #[test]
    fn test_result_processor_builder() {
        let config = ResultProcessorBuilder::new()
            .max_results(50)
            .enable_highlighting(vec!["title".to_string(), "content".to_string()])
            .enable_faceting(vec!["category".to_string()])
            .enable_snippets(300)
            .build();
        assert_eq!(config.max_results, 50);
        assert!(config.enable_highlighting);
        assert!(config.enable_faceting);
        assert!(config.enable_snippets);
        assert_eq!(config.snippet_length, 300);
    }

    // Sanity-check field access on a manually constructed hit.
    #[test]
    fn test_processed_hit_structure() {
        let mut fields = HashMap::new();
        fields.insert("title".to_string(), "Test Title".to_string());
        let hit = ProcessedHit {
            doc_id: 1,
            score: 0.95,
            fields,
            highlights: HashMap::new(),
            snippets: HashMap::new(),
            explanation: None,
        };
        assert_eq!(hit.doc_id, 1);
        // Exact f32 literal round-trips, so equality comparison is safe.
        assert_eq!(hit.score, 0.95);
        assert_eq!(hit.fields.get("title"), Some(&"Test Title".to_string()));
    }

    // Pattern-match both simple and multi-field aggregation variants.
    #[test]
    fn test_aggregation_results() {
        let count_agg = AggregationResult::Count { count: 100 };
        let avg_agg = AggregationResult::Average {
            avg: 0.75,
            count: 50,
        };
        match count_agg {
            AggregationResult::Count { count } => assert_eq!(count, 100),
            _ => panic!("Expected Count aggregation"),
        }
        match avg_agg {
            AggregationResult::Average { avg, count } => {
                assert_eq!(avg, 0.75);
                assert_eq!(count, 50);
            }
            _ => panic!("Expected Average aggregation"),
        }
    }
}