1use std::mem;
2use std::sync::Arc;
3use std::sync::atomic::AtomicBool;
4
5use ahash::AHashSet;
6use crate::common::counter::hardware_accumulator::HwMeasurementAcc;
7use crate::common::types::{DeferredBehavior, ScoreType};
8use ordered_float::OrderedFloat;
9use crate::segment::common::operation_error::{OperationError, OperationResult};
10use crate::segment::common::reciprocal_rank_fusion::rrf_scoring;
11use crate::segment::common::score_fusion::{ScoreFusion, score_fusion};
12use crate::segment::data_types::query_context::FormulaContext;
13use crate::segment::index::query_optimization::rescore_formula::parsed_formula::ParsedFormula;
14use crate::segment::types::{
15 Filter, HasIdCondition, ScoredPoint, WithPayload, WithPayloadInterface, WithVector,
16};
17use crate::shard::query::mmr::mmr_from_points_with_vector;
18use crate::shard::query::planned_query::*;
19use crate::shard::query::scroll::{QueryScrollRequestInternal, ScrollOrder};
20use crate::shard::query::*;
21use crate::shard::retrieve::retrieve_blocking::retrieve_blocking;
22use crate::shard::search::CoreSearchRequest;
23use crate::shard::search_result_aggregator::BatchResultAggregator;
24
25use super::EdgeShard;
26use crate::edge::DEFAULT_EDGE_TIMEOUT;
27
28impl EdgeShard {
29 pub fn query(&self, request: ShardQueryRequest) -> OperationResult<Vec<ScoredPoint>> {
30 let planned_query = PlannedQuery::try_from(vec![request])?;
31
32 let PlannedQuery {
33 root_plans,
34 searches,
35 scrolls,
36 } = planned_query;
37
38 let mut search_results = Vec::new();
39 for search in &searches {
40 search_results.push(self.search(search.clone())?);
41 }
42
43 let mut scroll_results = Vec::new();
44 for scroll in &scrolls {
45 scroll_results.push(self.query_scroll(scroll)?);
46 }
47
48 let mut scored_points_batch = Vec::new();
49 for root_plan in root_plans {
50 let scored_points = self.resolve_plan(
51 root_plan,
52 &mut search_results,
53 &mut scroll_results,
54 HwMeasurementAcc::disposable_edge(),
55 )?;
56
57 scored_points_batch.push(scored_points)
58 }
59
60 let [scored_points] = scored_points_batch
61 .try_into()
62 .map_err(|unconverted: Vec<_>| {
63 OperationError::service_error(format!(
64 "unexpected scored points batch size: expected 1, received {}",
65 unconverted.len(),
66 ))
67 })?;
68
69 Ok(scored_points)
70 }
71
72 fn resolve_plan(
73 &self,
74 root_plan: RootPlan,
75 search_results: &mut Vec<Vec<ScoredPoint>>,
76 scroll_results: &mut Vec<Vec<ScoredPoint>>,
77 hw_measurement_acc: HwMeasurementAcc,
78 ) -> OperationResult<Vec<ScoredPoint>> {
79 let RootPlan {
80 merge_plan,
81 with_payload,
82 with_vector,
83 } = root_plan;
84
85 let results = self.recurse_prefetch(
86 merge_plan,
87 search_results,
88 scroll_results,
89 0,
90 hw_measurement_acc.clone(),
91 )?;
92
93 let [result] = self
94 .fill_with_payload_or_vectors(
95 vec![results],
96 with_payload,
97 with_vector,
98 hw_measurement_acc,
99 )?
100 .try_into()
101 .map_err(|unconverted: Vec<_>| {
102 OperationError::service_error(format!(
103 "expected single result after filling payload/vectors, got {}",
104 unconverted.len(),
105 ))
106 })?;
107 Ok(result)
108 }
109
110 fn recurse_prefetch(
111 &self,
112 merge_plan: MergePlan,
113 search_results: &mut Vec<Vec<ScoredPoint>>,
114 scroll_results: &mut Vec<Vec<ScoredPoint>>,
115 depth: usize,
116 hw_counter_acc: HwMeasurementAcc,
117 ) -> OperationResult<Vec<ScoredPoint>> {
118 let MergePlan {
119 sources: merge_plan_sources,
120 rescore_stages,
121 } = merge_plan;
122
123 let max_len = merge_plan_sources.len();
124 let mut sources = Vec::with_capacity(max_len);
125
126 for source in merge_plan_sources {
128 match source {
129 Source::SearchesIdx(idx) => {
130 sources.push(take_prefetched_source(search_results, idx)?)
131 }
132
133 Source::ScrollsIdx(idx) => {
134 sources.push(take_prefetched_source(scroll_results, idx)?)
135 }
136
137 Source::Prefetch(merge_plan) => {
138 let merged = self.recurse_prefetch(
139 *merge_plan,
140 search_results,
141 scroll_results,
142 depth + 1,
143 hw_counter_acc.clone(),
144 )?;
145
146 sources.push(merged);
147 }
148 }
149 }
150
151 if let Some(rescore_stages) = rescore_stages {
152 let RescoreStages {
153 shard_level,
154 collection_level,
155 } = rescore_stages;
156
157 let shard_stage_result = if let Some(rescore_params) = shard_level {
158 vec![self.rescore(sources, rescore_params, hw_counter_acc.clone())?]
159 } else {
160 sources
161 };
162
163 let collection_result = if let Some(rescore_params) = collection_level {
164 self.rescore(shard_stage_result, rescore_params, hw_counter_acc)?
165 } else {
166 shard_stage_result.into_iter().next().unwrap_or_default()
168 };
169
170 Ok(collection_result)
172 } else {
173 debug_assert_eq!(depth, 0);
176 debug_assert_eq!(sources.len(), 1);
177 let [result] = sources.try_into().map_err(|unconverted: Vec<_>| {
178 OperationError::service_error(format!(
179 "expected single source without rescore stages, got {}",
180 unconverted.len(),
181 ))
182 })?;
183
184 Ok(result)
185 }
186 }
187
188 fn rescore(
189 &self,
190 sources: Vec<Vec<ScoredPoint>>,
191 rescore_params: RescoreParams,
192 hw_counter_acc: HwMeasurementAcc,
193 ) -> OperationResult<Vec<ScoredPoint>> {
194 let RescoreParams {
195 rescore,
196 score_threshold,
197 limit,
198 params,
199 } = rescore_params;
200
201 match rescore {
202 ScoringQuery::Fusion(fusion) => {
203 let top_fused = Self::fusion_rescore(
204 sources,
205 fusion,
206 score_threshold.map(OrderedFloat::into_inner),
207 limit,
208 )?;
209 Ok(top_fused)
210 }
211
212 ScoringQuery::OrderBy(order_by) => {
213 let filter = filter_by_point_ids(&sources);
215
216 let scroll_request = QueryScrollRequestInternal {
219 limit,
220 filter: Some(filter),
221 with_payload: false.into(),
222 with_vector: false.into(),
223 scroll_order: ScrollOrder::ByField(order_by),
224 };
225
226 self.query_scroll(&scroll_request)
227 }
228
229 ScoringQuery::Vector(query_enum) => {
230 let filter = filter_by_point_ids(&sources);
232
233 let search_request = CoreSearchRequest {
234 query: query_enum,
235 filter: Some(filter),
236 params,
237 limit,
238 offset: 0,
239 with_payload: None,
240 with_vector: None,
241 score_threshold: score_threshold.map(OrderedFloat::into_inner),
242 };
243
244 self.search(search_request)
245 }
246
247 ScoringQuery::Formula(formula) => self.rescore_with_formula(
248 formula,
249 sources,
250 limit,
251 score_threshold.map(OrderedFloat::into_inner),
252 hw_counter_acc,
253 ),
254
255 ScoringQuery::Sample(sample) => match sample {
256 SampleInternal::Random => {
257 let filter = filter_by_point_ids(&sources);
259
260 let scroll_request = QueryScrollRequestInternal {
262 limit,
263 filter: Some(filter),
264 with_payload: false.into(),
265 with_vector: false.into(),
266 scroll_order: ScrollOrder::Random,
267 };
268
269 self.query_scroll(&scroll_request)
270 }
271 },
272
273 ScoringQuery::Mmr(mmr) => self.mmr_rescore(sources, mmr, limit, hw_counter_acc),
274 }
275 }
276
277 fn fusion_rescore(
278 sources: Vec<Vec<ScoredPoint>>,
279 fusion: FusionInternal,
280 score_threshold: Option<f32>,
281 limit: usize,
282 ) -> OperationResult<Vec<ScoredPoint>> {
283 let fused = match fusion {
284 FusionInternal::Rrf { k, ref weights } => {
285 let weights_slice = weights
286 .as_ref()
287 .map(|w| w.iter().map(|f| f.into_inner()).collect::<Vec<_>>());
288 rrf_scoring(sources, k, weights_slice.as_deref())?
289 }
290 FusionInternal::Dbsf => score_fusion(sources, ScoreFusion::dbsf()),
291 };
292
293 let top_fused: Vec<_> = if let Some(score_threshold) = score_threshold {
294 fused
295 .into_iter()
296 .take_while(|point| point.score >= score_threshold)
297 .take(limit)
298 .collect()
299 } else {
300 fused.into_iter().take(limit).collect()
301 };
302
303 Ok(top_fused)
304 }
305
306 pub fn rescore_with_formula(
307 &self,
308 formula: ParsedFormula,
309 prefetches_results: Vec<Vec<ScoredPoint>>,
310 limit: usize,
311 score_threshold: Option<ScoreType>,
312 hw_measurement_acc: HwMeasurementAcc,
313 ) -> OperationResult<Vec<ScoredPoint>> {
314 let ctx = FormulaContext {
315 formula,
316 prefetches_results,
317 limit,
318 score_threshold,
319 is_stopped: Arc::new(AtomicBool::new(false)),
320 };
321
322 let ctx = Arc::new(ctx);
323 let hw_counter = hw_measurement_acc.get_counter_cell();
324
325 let mut rescored_results = Vec::new();
326
327 let segments = self
329 .segments
330 .read()
331 .non_appendable_then_appendable_segments()
332 .collect::<Vec<_>>();
333
334 for segment in segments {
335 let rescored_result = segment
336 .get()
337 .read()
338 .rescore_with_formula(ctx.clone(), &hw_counter)?;
339
340 rescored_results.push(rescored_result);
341 }
342
343 let mut aggregator = BatchResultAggregator::new(std::iter::once(limit));
345 aggregator.update_point_versions(rescored_results.iter().flatten());
346 aggregator.update_batch_results(0, rescored_results.into_iter().flatten());
347
348 let top =
349 aggregator.into_topk().into_iter().next().ok_or_else(|| {
350 OperationError::service_error("expected first result of aggregator")
351 })?;
352
353 Ok(top)
354 }
355
356 fn mmr_rescore(
358 &self,
359 sources: Vec<Vec<ScoredPoint>>,
360 mmr: MmrInternal,
361 limit: usize,
362 hw_measurement_acc: HwMeasurementAcc,
363 ) -> OperationResult<Vec<ScoredPoint>> {
364 let points_with_vector = self
365 .fill_with_payload_or_vectors(
366 sources,
367 false.into(),
368 WithVector::from(mmr.using.clone()),
369 hw_measurement_acc.clone(),
370 )?
371 .into_iter()
372 .flatten();
373
374 let vector_data_config = self
375 .config
376 .read()
377 .vector_data_config(&mmr.using)
378 .ok_or_else(|| {
379 OperationError::service_error(format!(
380 "vector data config for vector {} not found",
381 mmr.using,
382 ))
383 })?;
384
385 let mut top_mmr = mmr_from_points_with_vector(
387 points_with_vector,
388 mmr,
389 vector_data_config.distance,
390 vector_data_config.multivector_config,
391 limit,
392 hw_measurement_acc,
393 )?;
394
395 for point in &mut top_mmr {
397 point.vector = None;
398 }
399
400 Ok(top_mmr)
401 }
402
403 fn fill_with_payload_or_vectors(
405 &self,
406 query_response: ShardQueryResponse,
407 with_payload: WithPayloadInterface,
408 with_vector: WithVector,
409 hw_measurement_acc: HwMeasurementAcc,
410 ) -> OperationResult<ShardQueryResponse> {
411 if !with_payload.is_required() && !with_vector.is_enabled() {
412 return Ok(query_response);
413 }
414
415 let point_ids: Vec<_> = query_response
417 .iter()
418 .flatten()
419 .map(|scored_point| scored_point.id)
420 .collect();
421
422 let records_map = retrieve_blocking(
423 self.segments.clone(),
424 &point_ids,
425 &WithPayload::from(with_payload),
426 &with_vector,
427 DEFAULT_EDGE_TIMEOUT,
428 &AtomicBool::new(false),
429 hw_measurement_acc,
430 DeferredBehavior::Exclude,
431 )?;
432
433 let query_response: ShardQueryResponse = query_response
436 .into_iter()
437 .map(|points| {
438 points
439 .into_iter()
440 .filter_map(|mut point| {
441 records_map.get(&point.id).map(|record| {
442 point.payload.clone_from(&record.payload);
443 point.vector.clone_from(&record.vector);
444 point
445 })
446 })
447 .collect()
448 })
449 .collect();
450
451 Ok(query_response)
452 }
453}
454
455fn take_prefetched_source<T: Default>(items: &mut [T], index: usize) -> OperationResult<T> {
456 let source = items.get_mut(index).ok_or_else(|| {
457 OperationError::service_error(format!("prefetched source at index {index} does not exist"))
458 })?;
459
460 Ok(mem::take(source))
461}
462
463fn filter_by_point_ids(points: &[Vec<ScoredPoint>]) -> Filter {
465 let point_ids: AHashSet<_> = points.iter().flatten().map(|point| point.id).collect();
466
467 Filter::new_must(crate::segment::types::Condition::HasId(HasIdCondition::from(
469 point_ids,
470 )))
471}