use std::cmp::Ordering;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::lexical::core::field::FieldValue;
use crate::lexical::query::Hit;
use crate::lexical::query::Query;
use crate::lexical::reader::LexicalIndexReader;
/// A facet value addressed by its field name plus a list of path components.
///
/// A single value such as `"Apple"` is stored as `path == ["Apple"]`; a hierarchical
/// value such as `"tech/computers"` is stored with one component per level.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct FacetPath {
/// Name of the document field this facet belongs to.
pub field: String,
/// Path components from the top level down; its length is the facet's depth.
pub path: Vec<String>,
}
impl FacetPath {
    /// Creates a facet path for `field` from already-split `path` components.
    pub fn new(field: String, path: Vec<String>) -> Self {
        FacetPath { field, path }
    }

    /// Creates a single-level (depth 1) facet path whose only component is `value`.
    ///
    /// `value` is stored verbatim, even if it contains a delimiter or is empty.
    pub fn from_value(field: String, value: String) -> Self {
        FacetPath {
            field,
            path: vec![value],
        }
    }

    /// Splits `path_str` on `delimiter` into a hierarchical facet path.
    ///
    /// Empty components are dropped, so leading, trailing or doubled delimiters do not
    /// create phantom `""` levels: `"/a//b/"` becomes `["a", "b"]`. A string made only
    /// of delimiters (or the empty string) therefore yields an empty path of depth 0.
    pub fn from_delimited(field: String, path_str: &str, delimiter: &str) -> Self {
        let path = path_str
            .split(delimiter)
            .filter(|component| !component.is_empty())
            .map(str::to_string)
            .collect();
        FacetPath { field, path }
    }

    /// Returns the number of path components.
    pub fn depth(&self) -> usize {
        self.path.len()
    }

    /// Returns `true` when `self` is a strict ancestor of `other`.
    ///
    /// That means the same field, fewer components, and `self.path` as a prefix of
    /// `other.path`. A path is never its own parent.
    pub fn is_parent_of(&self, other: &FacetPath) -> bool {
        self.field == other.field
            && self.depth() < other.depth()
            && other.path.starts_with(&self.path)
    }

    /// Returns the path with its last component removed.
    ///
    /// Returns `None` for paths of depth 0 or 1: top-level values have no parent.
    pub fn parent(&self) -> Option<FacetPath> {
        if self.path.len() > 1 {
            Some(FacetPath {
                field: self.field.clone(),
                path: self.path[..self.path.len() - 1].to_vec(),
            })
        } else {
            None
        }
    }

    /// Returns a new path one level deeper, with `component` appended.
    pub fn child(&self, component: String) -> FacetPath {
        let mut child_path = self.path.clone();
        child_path.push(component);
        FacetPath {
            field: self.field.clone(),
            path: child_path,
        }
    }

    /// Joins the components with `delimiter`; the field name is not included.
    pub fn to_string_with_delimiter(&self, delimiter: &str) -> String {
        self.path.join(delimiter)
    }
}
/// The number of collected documents for one facet path, with optional child nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FacetCount {
/// The facet path being counted.
pub path: FacetPath,
/// Number of documents counted for this path; `FacetCollector` also counts a document
/// towards every ancestor of the paths it carries.
pub count: u64,
/// Child facet nodes; `FacetCollector::finalize` leaves this empty.
pub children: Vec<FacetCount>,
}
impl FacetCount {
pub fn new(path: FacetPath, count: u64) -> Self {
FacetCount {
path,
count,
children: Vec::new(),
}
}
pub fn add_child(&mut self, child: FacetCount) {
self.children.push(child);
}
pub fn sort_children(&mut self, by_count: bool) {
if by_count {
self.children.sort_by(|a, b| b.count.cmp(&a.count));
} else {
self.children
.sort_by(|a, b| a.path.path.last().cmp(&b.path.path.last()));
}
for child in &mut self.children {
child.sort_children(by_count);
}
}
}
/// Options that control how facet counts are filtered, ordered and limited.
#[derive(Debug, Clone)]
pub struct FacetConfig {
    /// Maximum number of facet entries kept per field after sorting.
    pub max_facets_per_field: usize,
    /// Intended maximum hierarchy depth.
    /// NOTE(review): not read by the collector in this module; confirm intended use.
    pub max_depth: usize,
    /// Entries whose count is below this value are dropped.
    pub min_count: u64,
    /// Whether zero-count entries should be reported.
    /// NOTE(review): not read by the collector in this module; confirm intended use.
    pub include_zero_counts: bool,
    /// Order by descending count when true, otherwise by path name.
    pub sort_by_count: bool,
}

impl Default for FacetConfig {
    /// Defaults: 100 entries per field, depth 10, minimum count 1, zero counts hidden,
    /// ordering by count.
    fn default() -> Self {
        Self {
            sort_by_count: true,
            include_zero_counts: false,
            min_count: 1,
            max_depth: 10,
            max_facets_per_field: 100,
        }
    }
}
/// Accumulates facet counts across documents; call `collect_doc` once per hit, then `finalize`.
#[derive(Debug)]
pub struct FacetCollector {
/// Options consulted by `finalize` (minimum count, ordering, per-field limit).
config: FacetConfig,
/// Running count per facet path; every ancestor of a collected path is counted too.
facet_counts: HashMap<FacetPath, u64>,
/// Fields whose values are read from each collected document.
facet_fields: Vec<String>,
}
impl FacetCollector {
    /// Creates a collector that counts values of `facet_fields` under `config`.
    pub fn new(config: FacetConfig, facet_fields: Vec<String>) -> Self {
        FacetCollector {
            config,
            facet_counts: HashMap::new(),
            facet_fields,
        }
    }

    /// Counts the facet values of document `doc_id` for every configured field.
    ///
    /// Each extracted path is counted once, and so is every ancestor returned by
    /// `FacetPath::parent`, so `a/b/c` also increments `a/b` and `a`.
    ///
    /// NOTE(review): `config.max_depth` is not applied here; confirm whether deep paths
    /// should be truncated.
    ///
    /// # Errors
    /// Propagates any error returned while reading the document from `reader`.
    pub fn collect_doc(&mut self, doc_id: u64, reader: &dyn LexicalIndexReader) -> Result<()> {
        for field_name in &self.facet_fields {
            for facet_path in self.get_doc_facet_values(doc_id, field_name, reader)? {
                // Take the parent before moving the path into the map, to avoid a clone.
                let mut ancestor = facet_path.parent();
                *self.facet_counts.entry(facet_path).or_insert(0) += 1;
                while let Some(path) = ancestor {
                    ancestor = path.parent();
                    *self.facet_counts.entry(path).or_insert(0) += 1;
                }
            }
        }
        Ok(())
    }

    /// Extracts the facet paths of `field_name` from document `doc_id`.
    ///
    /// Handling by value type:
    /// - Text containing `/` is split into a hierarchical path via `FacetPath::from_delimited`.
    /// - Other text, integers, floats and booleans become single-level paths of their
    ///   string form.
    /// - Any other variant becomes a single-level path of its `Debug` form.
    ///
    /// A missing document or a missing field yields no paths.
    ///
    /// # Errors
    /// Propagates the error from `reader.document` rather than inventing placeholder
    /// values that would be counted as real facets.
    fn get_doc_facet_values(
        &self,
        doc_id: u64,
        field_name: &str,
        reader: &dyn LexicalIndexReader,
    ) -> Result<Vec<FacetPath>> {
        let Some(document) = reader.document(doc_id)? else {
            return Ok(Vec::new());
        };
        let Some(val) = document.get(field_name) else {
            return Ok(Vec::new());
        };
        let field = field_name.to_string();
        let facet_path = match val {
            crate::data::DataValue::Text(value) if value.contains('/') => {
                FacetPath::from_delimited(field, value, "/")
            }
            crate::data::DataValue::Text(value) => FacetPath::from_value(field, value.clone()),
            crate::data::DataValue::Int64(value) => FacetPath::from_value(field, value.to_string()),
            crate::data::DataValue::Float64(value) => {
                FacetPath::from_value(field, value.to_string())
            }
            crate::data::DataValue::Bool(value) => FacetPath::from_value(field, value.to_string()),
            _ => FacetPath::from_value(field, format!("{val:?}")),
        };
        Ok(vec![facet_path])
    }

    /// Converts the accumulated counts into per-field `FacetResults`.
    ///
    /// Steps, per field:
    /// - Entries whose count is below `min_count` are dropped.
    /// - The remaining list is flat; children are not populated here.
    /// - It is ordered by descending count, or by first path component when
    ///   `sort_by_count` is false.
    /// - It is truncated to `max_facets_per_field`.
    ///
    /// Ties fall back to shallower depth and then to the full path. This makes both
    /// the order and the entries surviving truncation deterministic.
    ///
    /// NOTE(review): `include_zero_counts` is not consulted; every stored count is at least 1.
    pub fn finalize(self) -> Result<FacetResults> {
        let FacetCollector {
            config,
            facet_counts,
            ..
        } = self;
        let mut field_facets: HashMap<String, Vec<FacetCount>> = HashMap::new();
        for (facet_path, count) in facet_counts {
            if count >= config.min_count {
                field_facets
                    .entry(facet_path.field.clone())
                    .or_default()
                    .push(FacetCount::new(facet_path, count));
            }
        }
        for facets in field_facets.values_mut() {
            Self::sort_facets(facets, config.sort_by_count);
            facets.truncate(config.max_facets_per_field);
            facets
                .iter_mut()
                .for_each(|facet| facet.sort_children(config.sort_by_count));
        }
        Ok(FacetResults { field_facets })
    }

    /// Orders a flat facet list by a total key.
    ///
    /// The key is count (descending) or first component, then depth, then full path.
    fn sort_facets(facets: &mut [FacetCount], by_count: bool) {
        facets.sort_by(|a, b| {
            let primary = if by_count {
                b.count.cmp(&a.count)
            } else {
                a.path.path.first().cmp(&b.path.path.first())
            };
            primary
                .then_with(|| a.path.depth().cmp(&b.path.depth()))
                .then_with(|| a.path.path.cmp(&b.path.path))
        });
    }
}
/// Facet counts grouped by field name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FacetResults {
/// Facet entries for each field, keyed by field name.
pub field_facets: HashMap<String, Vec<FacetCount>>,
}
impl FacetResults {
    /// Returns results containing no fields.
    pub fn empty() -> Self {
        FacetResults {
            field_facets: HashMap::new(),
        }
    }

    /// Returns the facet list for `field_name`, if that field produced any entries.
    pub fn get_field_facets(&self, field_name: &str) -> Option<&Vec<FacetCount>> {
        self.field_facets.get(field_name)
    }

    /// Returns the number of top-level facet entries across all fields.
    ///
    /// Children are not counted.
    pub fn total_facet_count(&self) -> usize {
        self.field_facets.values().map(|facets| facets.len()).sum()
    }

    /// Merges `other` into `self`, summing the counts of identical facet paths.
    ///
    /// An entry of `other` whose path already exists in the same field is folded into
    /// the existing entry: counts are added and children are merged the same way,
    /// recursively. New paths are appended in `other`'s order. Existing order is kept,
    /// so callers that need a sorted list must re-sort after merging.
    pub fn merge(&mut self, other: FacetResults) {
        for (field, incoming) in other.field_facets {
            let existing = self.field_facets.entry(field).or_default();
            Self::merge_counts(existing, incoming);
        }
    }

    /// Folds `incoming` into `target`, matching entries by path.
    fn merge_counts(target: &mut Vec<FacetCount>, incoming: Vec<FacetCount>) {
        for facet in incoming {
            match target.iter().position(|entry| entry.path == facet.path) {
                Some(index) => {
                    let entry = &mut target[index];
                    entry.count += facet.count;
                    Self::merge_counts(&mut entry.children, facet.children);
                }
                None => target.push(facet),
            }
        }
    }
}
/// Facet constraints for a document.
///
/// Every required path must be present and no excluded path may be present. A filter
/// path also matches any document facet it is an ancestor of (see `FacetPath::is_parent_of`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FacetFilter {
/// Paths that must all be matched by some document facet.
pub required_paths: Vec<FacetPath>,
/// Paths that must not be matched by any document facet.
pub excluded_paths: Vec<FacetPath>,
}
impl FacetFilter {
pub fn new() -> Self {
FacetFilter {
required_paths: Vec::new(),
excluded_paths: Vec::new(),
}
}
pub fn require(&mut self, path: FacetPath) {
self.required_paths.push(path);
}
pub fn exclude(&mut self, path: FacetPath) {
self.excluded_paths.push(path);
}
pub fn matches_doc(&self, doc_facets: &[FacetPath]) -> bool {
for required_path in &self.required_paths {
let matches = doc_facets.iter().any(|doc_facet| {
doc_facet == required_path || required_path.is_parent_of(doc_facet)
});
if !matches {
return false;
}
}
for excluded_path in &self.excluded_paths {
let matches = doc_facets.iter().any(|doc_facet| {
doc_facet == excluded_path || excluded_path.is_parent_of(doc_facet)
});
if matches {
return false;
}
}
true
}
}
impl Default for FacetFilter {
fn default() -> Self {
Self::new()
}
}
/// Runs a query and collects facet counts over its hits.
#[derive(Debug)]
pub struct FacetedSearchEngine {
/// Configuration cloned into the `FacetCollector` created for each search.
facet_config: FacetConfig,
}
impl FacetedSearchEngine {
pub fn new(facet_config: FacetConfig) -> Self {
FacetedSearchEngine { facet_config }
}
pub fn search<Q: Query>(
&self,
query: Q,
facet_fields: Vec<String>,
facet_filter: Option<FacetFilter>,
reader: &dyn LexicalIndexReader,
) -> Result<FacetedSearchResults> {
let _matcher = query.matcher(reader)?;
let _scorer = query.scorer(reader)?;
let mut hits = Vec::new();
let mut facet_collector = FacetCollector::new(self.facet_config.clone(), facet_fields);
for doc_id in 0..10u64 {
let score = 1.0f32;
if let Some(ref filter) = facet_filter {
let doc_facets = self.get_document_facets(doc_id, reader)?;
if !filter.matches_doc(&doc_facets) {
continue;
}
}
hits.push(Hit {
doc_id,
score,
fields: HashMap::new(), });
facet_collector.collect_doc(doc_id, reader)?;
}
hits.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
let facet_results = facet_collector.finalize()?;
let total_hits = hits.len() as u64;
Ok(FacetedSearchResults {
hits,
facets: facet_results,
total_hits,
})
}
fn get_document_facets(
&self,
_doc_id: u64,
_reader: &dyn LexicalIndexReader,
) -> Result<Vec<FacetPath>> {
Ok(vec![])
}
}
/// Hits plus facet counts, as returned by `FacetedSearchEngine::search`.
#[derive(Debug, Serialize, Deserialize)]
pub struct FacetedSearchResults {
/// Matching documents; `search` orders them by descending score.
pub hits: Vec<Hit>,
/// Facet counts gathered over `hits`.
pub facets: FacetResults,
/// Number of hits; `search` sets this to `hits.len()`.
pub total_hits: u64,
}
impl FacetedSearchResults {
    /// Returns a result with no hits, no facets and a zero hit count.
    pub fn empty() -> Self {
        let facets = FacetResults::empty();
        Self {
            total_hits: 0,
            hits: Vec::new(),
            facets,
        }
    }
}
/// Schema-level description of a facet field, built with chained setters.
#[derive(Debug, Clone)]
pub struct FacetField {
    /// Field name.
    pub name: String,
    /// Whether values are hierarchical paths split on `delimiter`.
    pub hierarchical: bool,
    /// Separator between hierarchy levels (defaults to `/`).
    pub delimiter: String,
    /// Whether the field value is stored.
    pub stored: bool,
}

impl FacetField {
    /// Creates a flat, stored facet field named `name` with the `/` delimiter.
    pub fn new(name: String) -> Self {
        Self {
            name,
            stored: true,
            hierarchical: false,
            delimiter: String::from("/"),
        }
    }

    /// Marks the field as hierarchical, with levels separated by `delimiter`.
    pub fn hierarchical(self, delimiter: String) -> Self {
        Self {
            hierarchical: true,
            delimiter,
            ..self
        }
    }

    /// Sets whether the field value is stored.
    pub fn stored(self, stored: bool) -> Self {
        Self { stored, ..self }
    }
}
/// Options controlling how `GroupedSearchEngine` buckets and trims hits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupConfig {
/// Field whose value becomes each hit's group key.
pub group_field: String,
/// Maximum number of groups returned, applied after sorting.
pub max_groups: usize,
/// Maximum number of documents kept per group, applied after sorting by score.
pub max_docs_per_group: usize,
/// Order groups by descending size when true, by group key when false.
pub sort_by_count: bool,
}
impl Default for GroupConfig {
fn default() -> Self {
GroupConfig {
group_field: String::new(),
max_groups: 100,
max_docs_per_group: 10,
sort_by_count: true,
}
}
}
/// Hits that share one group key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchGroup {
/// The key shared by every document in this group.
pub group_key: String,
/// Hits in this group; `limit_documents` may trim this list.
pub documents: Vec<Hit>,
/// Number of hits ever added; not reduced when `documents` is trimmed.
pub total_docs: u64,
/// Highest-scoring hit added so far; on equal scores the earlier hit is kept.
pub representative_doc: Option<Hit>,
}
impl SearchGroup {
pub fn new(group_key: String) -> Self {
SearchGroup {
group_key,
documents: Vec::new(),
total_docs: 0,
representative_doc: None,
}
}
pub fn add_document(&mut self, hit: Hit) {
if self.representative_doc.is_none()
|| hit.score > self.representative_doc.as_ref().unwrap().score
{
self.representative_doc = Some(hit.clone());
}
self.documents.push(hit);
self.total_docs += 1;
}
pub fn sort_by_score(&mut self) {
self.documents
.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
}
pub fn limit_documents(&mut self, max_docs: usize) {
if self.documents.len() > max_docs {
self.documents.truncate(max_docs);
}
}
}
/// Output of `GroupedSearchEngine::search`.
#[derive(Debug, Serialize, Deserialize)]
pub struct GroupedSearchResults {
/// Groups after sorting and truncation to `max_groups`.
pub groups: Vec<SearchGroup>,
/// Number of scoring documents across all groups, before any trimming.
pub total_docs: u64,
/// Number of groups found, counted before truncation to `max_groups`.
pub total_groups: u64,
/// The configuration used to produce these results.
pub group_config: GroupConfig,
}
impl GroupedSearchResults {
pub fn empty(group_config: GroupConfig) -> Self {
GroupedSearchResults {
groups: Vec::new(),
total_docs: 0,
total_groups: 0,
group_config,
}
}
pub fn group_count(&self) -> usize {
self.groups.len()
}
pub fn get_group(&self, group_key: &str) -> Option<&SearchGroup> {
self.groups.iter().find(|g| g.group_key == group_key)
}
}
/// Runs a query and buckets the scoring documents by the value of one field.
#[derive(Debug)]
pub struct GroupedSearchEngine {
/// Grouping options applied to every search.
group_config: GroupConfig,
}
impl GroupedSearchEngine {
    /// Creates an engine that groups hits according to `group_config`.
    pub fn new(group_config: GroupConfig) -> Self {
        GroupedSearchEngine { group_config }
    }

    /// Runs `query` and buckets every scoring document by its group key.
    ///
    /// Group ordering:
    /// - With `sort_by_count`, groups are ordered by descending `total_docs`, with ties
    ///   broken by ascending `group_key`. Without the tie-break, both the order and the
    ///   groups surviving `max_groups` truncation would depend on `HashMap` iteration.
    /// - Otherwise they are ordered purely by `group_key`.
    ///
    /// `total_groups` is counted before truncation. Each kept group's documents are then
    /// sorted by descending score and trimmed to `max_docs_per_group`.
    ///
    /// NOTE(review): this is a placeholder scan.
    /// - Candidates are the fixed ids `0..100`.
    /// - The matcher is built but never consulted.
    /// - A document counts as a hit whenever `scorer.score(doc_id, 1.0, None) > 0.0`.
    ///
    /// Confirm against the real matcher API.
    ///
    /// # Errors
    /// Propagates errors from building the matcher/scorer and from reading documents.
    pub fn search<Q: Query>(
        &self,
        query: Q,
        reader: &dyn LexicalIndexReader,
    ) -> Result<GroupedSearchResults> {
        let _matcher = query.matcher(reader)?;
        let scorer = query.scorer(reader)?;
        let mut groups: HashMap<String, SearchGroup> = HashMap::new();
        let mut total_docs = 0u64;
        for doc_id in 0..100u64 {
            let score = scorer.score(doc_id, 1.0, None);
            if score > 0.0 {
                let group_key = self.get_document_group_key(doc_id, reader)?;
                let hit = Hit {
                    doc_id,
                    score,
                    fields: self.load_document_fields(doc_id, reader)?,
                };
                groups
                    .entry(group_key.clone())
                    .or_insert_with(|| SearchGroup::new(group_key))
                    .add_document(hit);
                total_docs += 1;
            }
        }
        let mut group_vec: Vec<SearchGroup> = groups.into_values().collect();
        if self.group_config.sort_by_count {
            group_vec.sort_by(|a, b| {
                b.total_docs
                    .cmp(&a.total_docs)
                    .then_with(|| a.group_key.cmp(&b.group_key))
            });
        } else {
            group_vec.sort_by(|a, b| a.group_key.cmp(&b.group_key));
        }
        let total_groups = group_vec.len() as u64;
        // Truncate first so documents of discarded groups are never sorted.
        group_vec.truncate(self.group_config.max_groups);
        for group in &mut group_vec {
            group.sort_by_score();
            group.limit_documents(self.group_config.max_docs_per_group);
        }
        Ok(GroupedSearchResults {
            groups: group_vec,
            total_docs,
            total_groups,
            group_config: self.group_config.clone(),
        })
    }

    /// Returns the group key of `doc_id`.
    ///
    /// The key is the string form of the document's `group_field` value (`Debug` form
    /// for variants without a dedicated arm), or `"unknown"` when that field is missing.
    ///
    /// NOTE(review): a document the reader does not have still gets the synthetic key
    /// `group_{doc_id % 5}`. This looks like placeholder behaviour; confirm whether such
    /// ids should be skipped instead.
    ///
    /// # Errors
    /// Propagates reader errors instead of mapping them to a placeholder key.
    fn get_document_group_key(
        &self,
        doc_id: u64,
        reader: &dyn LexicalIndexReader,
    ) -> Result<String> {
        let Some(document) = reader.document(doc_id)? else {
            return Ok(format!("group_{}", doc_id % 5));
        };
        let key = match document.get_field(&self.group_config.group_field) {
            Some(FieldValue::Text(value)) => value.clone(),
            Some(FieldValue::Int64(value)) => value.to_string(),
            Some(FieldValue::Float64(value)) => value.to_string(),
            Some(FieldValue::Bool(value)) => value.to_string(),
            Some(field_value) => format!("{field_value:?}"),
            None => "unknown".to_string(),
        };
        Ok(key)
    }

    /// Returns every field of `doc_id` rendered as a string (`Debug` form for variants
    /// without a dedicated arm).
    ///
    /// NOTE(review): a document the reader does not have yields synthetic `id` and
    /// `title` entries. This looks like placeholder behaviour; confirm it is wanted.
    ///
    /// # Errors
    /// Propagates reader errors instead of mapping them to placeholder fields.
    fn load_document_fields(
        &self,
        doc_id: u64,
        reader: &dyn LexicalIndexReader,
    ) -> Result<HashMap<String, String>> {
        let Some(document) = reader.document(doc_id)? else {
            let mut placeholder = HashMap::new();
            placeholder.insert("id".to_string(), doc_id.to_string());
            placeholder.insert("title".to_string(), format!("Document {doc_id}"));
            return Ok(placeholder);
        };
        let mut fields = HashMap::new();
        for (field_name, field_value) in &document.fields {
            let value_str = match field_value {
                crate::data::DataValue::Text(value) => value.clone(),
                crate::data::DataValue::Int64(value) => value.to_string(),
                crate::data::DataValue::Float64(value) => value.to_string(),
                crate::data::DataValue::Bool(value) => value.to_string(),
                _ => format!("{field_value:?}"),
            };
            fields.insert(field_name.clone(), value_str);
        }
        Ok(fields)
    }
}
/// Counts of numeric values falling into labelled ranges of one field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RangeFacet {
/// Field the values belong to.
pub field: String,
/// Buckets to count into; `count_ranges` counts a value in every bucket that contains it.
pub ranges: Vec<FacetRange>,
}
/// A half-open numeric bucket `[min, max)`; a `None` bound leaves that side unbounded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FacetRange {
/// Display label, e.g. `[0.0 TO 10.0)`.
pub label: String,
/// Inclusive lower bound, or `None` for no lower bound.
pub min: Option<f64>,
/// Exclusive upper bound, or `None` for no upper bound.
pub max: Option<f64>,
/// Number of values counted into this bucket by the last `RangeFacet::count_ranges` call.
pub count: u64,
}
impl FacetRange {
pub fn new(label: String, min: Option<f64>, max: Option<f64>) -> Self {
FacetRange {
label,
min,
max,
count: 0,
}
}
pub fn contains(&self, value: f64) -> bool {
let min_ok = self.min.is_none_or(|min| value >= min);
let max_ok = self.max.is_none_or(|max| value < max);
min_ok && max_ok
}
}
impl RangeFacet {
    /// Creates a range facet for `field` over the given buckets.
    pub fn new(field: String, ranges: Vec<FacetRange>) -> Self {
        Self { field, ranges }
    }

    /// Builds `count` equal-width buckets starting at `min`.
    ///
    /// Each bucket is `step = (max - min) / count` wide. The last bucket has no upper
    /// bound, so values at or above `max` fall into it. A `count` of 0 yields no buckets.
    pub fn numeric_ranges(field: String, min: f64, max: f64, count: usize) -> Self {
        let step = (max - min) / count as f64;
        let ranges = (0..count)
            .map(|i| {
                let lower = min + (i as f64 * step);
                let is_last = i + 1 == count;
                let upper = if is_last {
                    None
                } else {
                    Some(min + ((i + 1) as f64 * step))
                };
                let label = match upper {
                    Some(upper_val) => format!("[{lower:.1} TO {upper_val:.1})"),
                    None => format!("[{lower:.1} TO *]"),
                };
                FacetRange::new(label, Some(lower), upper)
            })
            .collect();
        RangeFacet::new(field, ranges)
    }

    /// Recomputes every bucket's count from `values`, discarding previous counts.
    ///
    /// A value is counted in every bucket that contains it.
    pub fn count_ranges(&mut self, values: &[f64]) {
        for range in self.ranges.iter_mut() {
            let matched = values.iter().filter(|&&v| range.contains(v)).count();
            range.count = matched as u64;
        }
    }
}
// Unit tests for the facet, grouping and range-facet types. They use only in-memory
// values and never construct an index reader or run a search engine.
#[cfg(test)]
mod tests {
use super::*;
// FacetPath constructors: explicit components, a single value, and a '/'-delimited string.
#[test]
fn test_facet_path_creation() {
let path = FacetPath::new(
"category".to_string(),
vec!["Electronics".to_string(), "Computers".to_string()],
);
assert_eq!(path.field, "category");
assert_eq!(path.depth(), 2);
let single_path = FacetPath::from_value("brand".to_string(), "Apple".to_string());
assert_eq!(single_path.depth(), 1);
assert_eq!(single_path.path[0], "Apple");
let delimited_path =
FacetPath::from_delimited("tags".to_string(), "tech/computers/laptops", "/");
assert_eq!(delimited_path.depth(), 3);
assert_eq!(delimited_path.path, vec!["tech", "computers", "laptops"]);
}
// Ancestor checks are strict and span multiple levels; `parent` drops the last component.
#[test]
fn test_facet_path_hierarchy() {
let parent = FacetPath::new("category".to_string(), vec!["Electronics".to_string()]);
let child = FacetPath::new(
"category".to_string(),
vec!["Electronics".to_string(), "Computers".to_string()],
);
assert!(parent.is_parent_of(&child));
assert!(!child.is_parent_of(&parent));
let grandchild = child.child("Laptops".to_string());
assert_eq!(grandchild.depth(), 3);
assert!(child.is_parent_of(&grandchild));
assert!(parent.is_parent_of(&grandchild));
let child_parent = child.parent().unwrap();
assert_eq!(child_parent, parent);
}
// `add_child` appends a node without changing the parent's own count.
#[test]
fn test_facet_count() {
let path = FacetPath::from_value("category".to_string(), "Electronics".to_string());
let mut facet_count = FacetCount::new(path, 42);
assert_eq!(facet_count.count, 42);
assert_eq!(facet_count.children.len(), 0);
let child_path = FacetPath::from_value("category".to_string(), "Computers".to_string());
let child_count = FacetCount::new(child_path, 15);
facet_count.add_child(child_count);
assert_eq!(facet_count.children.len(), 1);
assert_eq!(facet_count.children[0].count, 15);
}
// Required paths must all be present; any excluded path rejects the document.
#[test]
fn test_facet_filter() {
let mut filter = FacetFilter::new();
filter.require(FacetPath::from_value(
"category".to_string(),
"Electronics".to_string(),
));
filter.exclude(FacetPath::from_value(
"brand".to_string(),
"Acme".to_string(),
));
let doc_facets = vec![
FacetPath::from_value("category".to_string(), "Electronics".to_string()),
FacetPath::from_value("brand".to_string(), "Apple".to_string()),
];
assert!(filter.matches_doc(&doc_facets));
let doc_facets2 = vec![FacetPath::from_value(
"category".to_string(),
"Books".to_string(),
)];
assert!(!filter.matches_doc(&doc_facets2));
let doc_facets3 = vec![
FacetPath::from_value("category".to_string(), "Electronics".to_string()),
FacetPath::from_value("brand".to_string(), "Acme".to_string()),
];
assert!(!filter.matches_doc(&doc_facets3));
}
// Pins the values returned by `FacetConfig::default()`.
#[test]
fn test_facet_config() {
let config = FacetConfig::default();
assert_eq!(config.max_facets_per_field, 100);
assert_eq!(config.max_depth, 10);
assert_eq!(config.min_count, 1);
assert!(!config.include_zero_counts);
assert!(config.sort_by_count);
}
// Field lookup and the top-level entry total on a hand-built result.
#[test]
fn test_facet_results() {
let mut results = FacetResults::empty();
assert_eq!(results.total_facet_count(), 0);
let path = FacetPath::from_value("category".to_string(), "Electronics".to_string());
let facet_count = FacetCount::new(path, 42);
results
.field_facets
.insert("category".to_string(), vec![facet_count]);
assert_eq!(results.total_facet_count(), 1);
assert!(results.get_field_facets("category").is_some());
assert!(results.get_field_facets("nonexistent").is_none());
}
// Pins the values returned by `GroupConfig::default()`.
#[test]
fn test_group_config() {
let config = GroupConfig::default();
assert!(config.group_field.is_empty());
assert_eq!(config.max_groups, 100);
assert_eq!(config.max_docs_per_group, 10);
assert!(config.sort_by_count);
}
// The representative doc tracks the best score; sorting is descending; limiting trims `documents`.
#[test]
fn test_search_group() {
let mut group = SearchGroup::new("Electronics".to_string());
assert_eq!(group.group_key, "Electronics");
assert_eq!(group.total_docs, 0);
assert!(group.representative_doc.is_none());
let hit1 = Hit {
doc_id: 1,
score: 0.8,
fields: HashMap::new(),
};
let hit2 = Hit {
doc_id: 2,
score: 0.9,
fields: HashMap::new(),
};
group.add_document(hit1);
group.add_document(hit2);
assert_eq!(group.total_docs, 2);
assert_eq!(group.documents.len(), 2);
assert_eq!(group.representative_doc.as_ref().unwrap().score, 0.9);
group.sort_by_score();
assert_eq!(group.documents[0].score, 0.9);
assert_eq!(group.documents[1].score, 0.8);
group.limit_documents(1);
assert_eq!(group.documents.len(), 1);
}
// An empty grouped result has no groups and finds nothing by key.
#[test]
fn test_grouped_search_results() {
let config = GroupConfig {
group_field: "category".to_string(),
max_groups: 10,
max_docs_per_group: 5,
sort_by_count: true,
};
let results = GroupedSearchResults::empty(config.clone());
assert_eq!(results.group_count(), 0);
assert_eq!(results.total_docs, 0);
assert_eq!(results.total_groups, 0);
assert!(results.get_group("Electronics").is_none());
}
// Buckets are half-open: `min` is inclusive and `max` is exclusive.
#[test]
fn test_facet_range() {
let range = FacetRange::new("[0.0 TO 10.0)".to_string(), Some(0.0), Some(10.0));
assert!(range.contains(5.0));
assert!(range.contains(0.0)); assert!(!range.contains(10.0)); assert!(!range.contains(-1.0));
assert!(!range.contains(15.0));
}
// Five equal-width buckets over [0, 100); the last one has no upper bound.
#[test]
fn test_range_facet_creation() {
let range_facet = RangeFacet::numeric_ranges("price".to_string(), 0.0, 100.0, 5);
assert_eq!(range_facet.field, "price");
assert_eq!(range_facet.ranges.len(), 5);
assert_eq!(range_facet.ranges[0].min, Some(0.0));
assert_eq!(range_facet.ranges[0].max, Some(20.0));
assert_eq!(range_facet.ranges[4].min, Some(80.0));
assert_eq!(range_facet.ranges[4].max, None); }
// 15.0 lies beyond `max` but is still counted, by the open-ended last bucket.
#[test]
fn test_range_facet_counting() {
let mut range_facet = RangeFacet::numeric_ranges("score".to_string(), 0.0, 10.0, 2);
let values = vec![1.0, 3.0, 7.0, 9.0, 15.0];
range_facet.count_ranges(&values);
assert_eq!(range_facet.ranges[0].count, 2);
assert_eq!(range_facet.ranges[1].count, 3);
}
}