use std::sync::Arc;
use crate::analysis::AnalyzerRegistry;
use crate::core::DocId;
use crate::segment::reader::SegmentReader;
use crate::vector::global::GlobalHnsw;
/// Bundles a set of segment readers with the analyzer registry, an optional
/// field mapping, precomputed per-field average norm values, and an optional
/// shared `GlobalHnsw` handle.
pub struct SegmentStore {
/// Segment readers, in the order they were passed to `SegmentStore::new`.
pub(crate) segments: Vec<SegmentReader>,
/// Analyzer registry, exposed via `analyzers()`.
pub(crate) analyzers: AnalyzerRegistry,
/// Optional field mapping; `resolve_search_analyzer` consults it for a
/// field's `search_analyzer` / `analyzer`.
pub(crate) mapping: Option<crate::mapping::Mapping>,
/// Field name -> mean per-document norm value across all segments, computed
/// once when the store is built. Fields without any norms are absent.
/// NOTE(review): presumably the norms encode field lengths — confirm.
pub(crate) avg_field_lengths: std::collections::HashMap<String, f32>,
/// Optional shared `GlobalHnsw` index handle.
pub(crate) global_hnsw: Option<Arc<GlobalHnsw>>,
}
impl SegmentStore {
    /// Builds a store over `segments`.
    ///
    /// Per-field average norm values are computed once here from exactly the
    /// segments passed in; they are not updated afterwards.
    pub fn new(
        segments: Vec<SegmentReader>,
        analyzers: AnalyzerRegistry,
        mapping: Option<crate::mapping::Mapping>,
        global_hnsw: Option<Arc<GlobalHnsw>>,
    ) -> Self {
        let avg_field_lengths = Self::compute_avg_field_lengths(&segments);
        Self {
            segments,
            analyzers,
            mapping,
            avg_field_lengths,
            global_hnsw,
        }
    }

    /// Returns the `GlobalHnsw` handle supplied to `new`, if any.
    pub(crate) fn global_hnsw(&self) -> Option<&Arc<GlobalHnsw>> {
        self.global_hnsw.as_ref()
    }

    /// Returns the segment readers in the order they were supplied.
    pub(crate) fn segments(&self) -> &[SegmentReader] {
        &self.segments
    }

    /// Sum of `doc_count()` over all segments.
    pub(crate) fn total_docs(&self) -> u32 {
        self.segments.iter().map(|s| s.doc_count()).sum()
    }

    /// Mean per-document norm value for `field_name` across all segments.
    ///
    /// Returns `0.0` when no segment provided norms (covering at least one
    /// document) for that field.
    pub(crate) fn avg_field_length(&self, field_name: &str) -> f32 {
        self.avg_field_lengths
            .get(field_name)
            .copied()
            .unwrap_or(0.0)
    }

    /// Sums each segment's `doc_freq` for `term` in the field named `field_name`.
    ///
    /// A segment whose header does not declare `field_name` contributes `0`.
    /// If a header lists the name more than once, only the first entry is used.
    pub(crate) fn doc_freq(&self, field_name: &str, term: &str) -> u32 {
        self.segments
            .iter()
            .map(|s| {
                s.header()
                    .fields
                    .iter()
                    .find(|f| f.field_name == field_name)
                    .map_or(0, |f| s.doc_freq(f.field_id, term))
            })
            .sum()
    }

    /// Returns the analyzer registry supplied to `new`.
    pub(crate) fn analyzers(&self) -> &AnalyzerRegistry {
        &self.analyzers
    }

    /// Returns the field mapping supplied to `new`, if any.
    pub(crate) fn mapping(&self) -> Option<&crate::mapping::Mapping> {
        self.mapping.as_ref()
    }

    /// Picks the analyzer name to use for query text on `field`.
    ///
    /// Precedence:
    /// 1. An explicit `query_analyzer`.
    /// 2. The mapped field's `search_analyzer`.
    /// 3. The mapped field's `analyzer`.
    /// 4. `"standard"` when no mapping, no field entry, or no analyzer is set.
    pub(crate) fn resolve_search_analyzer<'a>(
        &'a self,
        field: &str,
        query_analyzer: Option<&'a str>,
    ) -> &'a str {
        if let Some(a) = query_analyzer {
            return a;
        }
        if let Some(ref mapping) = self.mapping {
            if let Some(field_id) = mapping.field_id(field) {
                let field_mapping = mapping.field(field_id);
                if let Some(ref sa) = field_mapping.search_analyzer {
                    return sa;
                }
                if let Some(ref a) = field_mapping.analyzer {
                    return a;
                }
            }
        }
        "standard"
    }

    /// Computes, for every field declared in any segment header, the mean of
    /// its per-document norm values over all segments that provide norms.
    ///
    /// Fields for which no segment provides norms covering at least one
    /// document are omitted from the result.
    fn compute_avg_field_lengths(
        segments: &[SegmentReader],
    ) -> std::collections::HashMap<String, f32> {
        // Field name -> (sum of norm values, number of documents with norms).
        // A single pass over each header replaces re-scanning every segment
        // header once per distinct field name. Per field, additions still
        // happen in segment order, then document order, so the f64 sums match.
        // The document counter is u64 so that summing many segments' u32
        // `doc_count()` values cannot overflow.
        let mut totals: std::collections::HashMap<String, (f64, u64)> =
            std::collections::HashMap::new();

        for segment in segments {
            let header = segment.header();
            // Only the first header entry for a given name counts, matching the
            // `find`-based lookup used elsewhere.
            let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::new();
            for field in &header.fields {
                if !seen.insert(field.field_name.as_str()) {
                    continue;
                }
                if let Some(norms) = segment.norms(field.field_id) {
                    let doc_count = norms.doc_count();
                    let slot = totals
                        .entry(field.field_name.clone())
                        .or_insert((0.0, 0));
                    for i in 0..doc_count {
                        slot.0 += norms.norm(DocId::new(i)) as f64;
                    }
                    slot.1 += u64::from(doc_count);
                }
            }
        }

        totals
            .into_iter()
            .filter(|&(_, (_, docs))| docs > 0)
            .map(|(name, (sum, docs))| (name, (sum / docs as f64) as f32))
            .collect()
    }
}