// luci/search/segment_store.rs

use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use crate::analysis::AnalyzerRegistry;
use crate::core::DocId;
use crate::segment::reader::SegmentReader;
use crate::vector::global::GlobalHnsw;
22pub struct SegmentStore {
28 pub(crate) segments: Vec<SegmentReader>,
29 pub(crate) analyzers: AnalyzerRegistry,
30 pub(crate) mapping: Option<crate::mapping::Mapping>,
31 pub(crate) avg_field_lengths: std::collections::HashMap<String, f32>,
32 pub(crate) global_hnsw: Option<Arc<GlobalHnsw>>,
37}
38
39impl SegmentStore {
40 pub fn new(
44 segments: Vec<SegmentReader>,
45 analyzers: AnalyzerRegistry,
46 mapping: Option<crate::mapping::Mapping>,
47 global_hnsw: Option<Arc<GlobalHnsw>>,
48 ) -> Self {
49 let avg_field_lengths = Self::compute_avg_field_lengths(&segments);
50 Self {
51 segments,
52 analyzers,
53 mapping,
54 avg_field_lengths,
55 global_hnsw,
56 }
57 }
58
59 pub(crate) fn global_hnsw(&self) -> Option<&Arc<GlobalHnsw>> {
61 self.global_hnsw.as_ref()
62 }
63
64 pub(crate) fn segments(&self) -> &[SegmentReader] {
66 &self.segments
67 }
68
69 pub(crate) fn total_docs(&self) -> u32 {
71 self.segments.iter().map(|s| s.doc_count()).sum()
72 }
73
74 pub(crate) fn avg_field_length(&self, field_name: &str) -> f32 {
76 self.avg_field_lengths
77 .get(field_name)
78 .copied()
79 .unwrap_or(0.0)
80 }
81
82 pub(crate) fn doc_freq(&self, field_name: &str, term: &str) -> u32 {
84 self.segments
85 .iter()
86 .map(|s| {
87 let field_id = s
88 .header()
89 .fields
90 .iter()
91 .find(|f| f.field_name == field_name)
92 .map(|f| f.field_id);
93 match field_id {
94 Some(fid) => s.doc_freq(fid, term),
95 None => 0,
96 }
97 })
98 .sum()
99 }
100
101 pub(crate) fn analyzers(&self) -> &AnalyzerRegistry {
103 &self.analyzers
104 }
105
106 pub(crate) fn mapping(&self) -> Option<&crate::mapping::Mapping> {
108 self.mapping.as_ref()
109 }
110
111 pub(crate) fn resolve_search_analyzer<'a>(
113 &'a self,
114 field: &str,
115 query_analyzer: Option<&'a str>,
116 ) -> &'a str {
117 if let Some(a) = query_analyzer {
118 return a;
119 }
120 if let Some(ref mapping) = self.mapping {
121 if let Some(field_id) = mapping.field_id(field) {
122 let field_mapping = mapping.field(field_id);
123 if let Some(ref sa) = field_mapping.search_analyzer {
124 return sa;
125 }
126 if let Some(ref a) = field_mapping.analyzer {
127 return a;
128 }
129 }
130 }
131 "standard"
132 }
133
134 fn compute_avg_field_lengths(
136 segments: &[SegmentReader],
137 ) -> std::collections::HashMap<String, f32> {
138 let mut field_names: std::collections::HashSet<String> = std::collections::HashSet::new();
139 for segment in segments {
140 for field in &segment.header().fields {
141 field_names.insert(field.field_name.clone());
142 }
143 }
144
145 let mut result = std::collections::HashMap::new();
146 for field_name in &field_names {
147 let mut total_length = 0.0f64;
148 let mut total_docs = 0u32;
149 for segment in segments {
150 let field_id = segment
151 .header()
152 .fields
153 .iter()
154 .find(|f| &f.field_name == field_name)
155 .map(|f| f.field_id);
156 if let Some(fid) = field_id {
157 if let Some(norms) = segment.norms(fid) {
158 for i in 0..norms.doc_count() {
159 total_length += norms.norm(DocId::new(i)) as f64;
160 }
161 total_docs += norms.doc_count();
162 }
163 }
164 }
165 if total_docs > 0 {
166 result.insert(
167 field_name.clone(),
168 (total_length / total_docs as f64) as f32,
169 );
170 }
171 }
172 result
173 }
174}