// qdrant_edge/edge/query.rs

use std::mem;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;

use ahash::AHashSet;
use ordered_float::OrderedFloat;

use crate::common::counter::hardware_accumulator::HwMeasurementAcc;
use crate::common::types::{DeferredBehavior, ScoreType};
use crate::segment::common::operation_error::{OperationError, OperationResult};
use crate::segment::common::reciprocal_rank_fusion::rrf_scoring;
use crate::segment::common::score_fusion::{ScoreFusion, score_fusion};
use crate::segment::data_types::query_context::FormulaContext;
use crate::segment::index::query_optimization::rescore_formula::parsed_formula::ParsedFormula;
use crate::segment::types::{
    Filter, HasIdCondition, ScoredPoint, WithPayload, WithPayloadInterface, WithVector,
};
use crate::shard::query::mmr::mmr_from_points_with_vector;
use crate::shard::query::planned_query::*;
use crate::shard::query::scroll::{QueryScrollRequestInternal, ScrollOrder};
use crate::shard::query::*;
use crate::shard::retrieve::retrieve_blocking::retrieve_blocking;
use crate::shard::search::CoreSearchRequest;
use crate::shard::search_result_aggregator::BatchResultAggregator;

use super::EdgeShard;
use crate::edge::DEFAULT_EDGE_TIMEOUT;
27
28impl EdgeShard {
29    pub fn query(&self, request: ShardQueryRequest) -> OperationResult<Vec<ScoredPoint>> {
30        let planned_query = PlannedQuery::try_from(vec![request])?;
31
32        let PlannedQuery {
33            root_plans,
34            searches,
35            scrolls,
36        } = planned_query;
37
38        let mut search_results = Vec::new();
39        for search in &searches {
40            search_results.push(self.search(search.clone())?);
41        }
42
43        let mut scroll_results = Vec::new();
44        for scroll in &scrolls {
45            scroll_results.push(self.query_scroll(scroll)?);
46        }
47
48        let mut scored_points_batch = Vec::new();
49        for root_plan in root_plans {
50            let scored_points = self.resolve_plan(
51                root_plan,
52                &mut search_results,
53                &mut scroll_results,
54                HwMeasurementAcc::disposable_edge(),
55            )?;
56
57            scored_points_batch.push(scored_points)
58        }
59
60        let [scored_points] = scored_points_batch
61            .try_into()
62            .map_err(|unconverted: Vec<_>| {
63                OperationError::service_error(format!(
64                    "unexpected scored points batch size: expected 1, received {}",
65                    unconverted.len(),
66                ))
67            })?;
68
69        Ok(scored_points)
70    }
71
72    fn resolve_plan(
73        &self,
74        root_plan: RootPlan,
75        search_results: &mut Vec<Vec<ScoredPoint>>,
76        scroll_results: &mut Vec<Vec<ScoredPoint>>,
77        hw_measurement_acc: HwMeasurementAcc,
78    ) -> OperationResult<Vec<ScoredPoint>> {
79        let RootPlan {
80            merge_plan,
81            with_payload,
82            with_vector,
83        } = root_plan;
84
85        let results = self.recurse_prefetch(
86            merge_plan,
87            search_results,
88            scroll_results,
89            0,
90            hw_measurement_acc.clone(),
91        )?;
92
93        let [result] = self
94            .fill_with_payload_or_vectors(
95                vec![results],
96                with_payload,
97                with_vector,
98                hw_measurement_acc,
99            )?
100            .try_into()
101            .map_err(|unconverted: Vec<_>| {
102                OperationError::service_error(format!(
103                    "expected single result after filling payload/vectors, got {}",
104                    unconverted.len(),
105                ))
106            })?;
107        Ok(result)
108    }
109
110    fn recurse_prefetch(
111        &self,
112        merge_plan: MergePlan,
113        search_results: &mut Vec<Vec<ScoredPoint>>,
114        scroll_results: &mut Vec<Vec<ScoredPoint>>,
115        depth: usize,
116        hw_counter_acc: HwMeasurementAcc,
117    ) -> OperationResult<Vec<ScoredPoint>> {
118        let MergePlan {
119            sources: merge_plan_sources,
120            rescore_stages,
121        } = merge_plan;
122
123        let max_len = merge_plan_sources.len();
124        let mut sources = Vec::with_capacity(max_len);
125
126        // We need to preserve the order of the sources for some fusion strategies
127        for source in merge_plan_sources {
128            match source {
129                Source::SearchesIdx(idx) => {
130                    sources.push(take_prefetched_source(search_results, idx)?)
131                }
132
133                Source::ScrollsIdx(idx) => {
134                    sources.push(take_prefetched_source(scroll_results, idx)?)
135                }
136
137                Source::Prefetch(merge_plan) => {
138                    let merged = self.recurse_prefetch(
139                        *merge_plan,
140                        search_results,
141                        scroll_results,
142                        depth + 1,
143                        hw_counter_acc.clone(),
144                    )?;
145
146                    sources.push(merged);
147                }
148            }
149        }
150
151        if let Some(rescore_stages) = rescore_stages {
152            let RescoreStages {
153                shard_level,
154                collection_level,
155            } = rescore_stages;
156
157            let shard_stage_result = if let Some(rescore_params) = shard_level {
158                vec![self.rescore(sources, rescore_params, hw_counter_acc.clone())?]
159            } else {
160                sources
161            };
162
163            let collection_result = if let Some(rescore_params) = collection_level {
164                self.rescore(shard_stage_result, rescore_params, hw_counter_acc)?
165            } else {
166                // Only one shard result is expected at this point.
167                shard_stage_result.into_iter().next().unwrap_or_default()
168            };
169
170            // In Edge, both shard-level and collection-level rescoring are handled the same way.
171            Ok(collection_result)
172        } else {
173            // The sources here are passed to the next layer without any extra processing.
174            // It should be a query without prefetches.
175            debug_assert_eq!(depth, 0);
176            debug_assert_eq!(sources.len(), 1);
177            let [result] = sources.try_into().map_err(|unconverted: Vec<_>| {
178                OperationError::service_error(format!(
179                    "expected single source without rescore stages, got {}",
180                    unconverted.len(),
181                ))
182            })?;
183
184            Ok(result)
185        }
186    }
187
188    fn rescore(
189        &self,
190        sources: Vec<Vec<ScoredPoint>>,
191        rescore_params: RescoreParams,
192        hw_counter_acc: HwMeasurementAcc,
193    ) -> OperationResult<Vec<ScoredPoint>> {
194        let RescoreParams {
195            rescore,
196            score_threshold,
197            limit,
198            params,
199        } = rescore_params;
200
201        match rescore {
202            ScoringQuery::Fusion(fusion) => {
203                let top_fused = Self::fusion_rescore(
204                    sources,
205                    fusion,
206                    score_threshold.map(OrderedFloat::into_inner),
207                    limit,
208                )?;
209                Ok(top_fused)
210            }
211
212            ScoringQuery::OrderBy(order_by) => {
213                // create single scroll request for rescoring query
214                let filter = filter_by_point_ids(&sources);
215
216                // Note: score_threshold is not used in this case, as all results will have same score,
217                // but different order_value
218                let scroll_request = QueryScrollRequestInternal {
219                    limit,
220                    filter: Some(filter),
221                    with_payload: false.into(),
222                    with_vector: false.into(),
223                    scroll_order: ScrollOrder::ByField(order_by),
224                };
225
226                self.query_scroll(&scroll_request)
227            }
228
229            ScoringQuery::Vector(query_enum) => {
230                // create single search request for rescoring query
231                let filter = filter_by_point_ids(&sources);
232
233                let search_request = CoreSearchRequest {
234                    query: query_enum,
235                    filter: Some(filter),
236                    params,
237                    limit,
238                    offset: 0,
239                    with_payload: None,
240                    with_vector: None,
241                    score_threshold: score_threshold.map(OrderedFloat::into_inner),
242                };
243
244                self.search(search_request)
245            }
246
247            ScoringQuery::Formula(formula) => self.rescore_with_formula(
248                formula,
249                sources,
250                limit,
251                score_threshold.map(OrderedFloat::into_inner),
252                hw_counter_acc,
253            ),
254
255            ScoringQuery::Sample(sample) => match sample {
256                SampleInternal::Random => {
257                    // create single scroll request for rescoring query
258                    let filter = filter_by_point_ids(&sources);
259
260                    // Note: score_threshold is not used in this case, as all results will have same score and order_value
261                    let scroll_request = QueryScrollRequestInternal {
262                        limit,
263                        filter: Some(filter),
264                        with_payload: false.into(),
265                        with_vector: false.into(),
266                        scroll_order: ScrollOrder::Random,
267                    };
268
269                    self.query_scroll(&scroll_request)
270                }
271            },
272
273            ScoringQuery::Mmr(mmr) => self.mmr_rescore(sources, mmr, limit, hw_counter_acc),
274        }
275    }
276
277    fn fusion_rescore(
278        sources: Vec<Vec<ScoredPoint>>,
279        fusion: FusionInternal,
280        score_threshold: Option<f32>,
281        limit: usize,
282    ) -> OperationResult<Vec<ScoredPoint>> {
283        let fused = match fusion {
284            FusionInternal::Rrf { k, ref weights } => {
285                let weights_slice = weights
286                    .as_ref()
287                    .map(|w| w.iter().map(|f| f.into_inner()).collect::<Vec<_>>());
288                rrf_scoring(sources, k, weights_slice.as_deref())?
289            }
290            FusionInternal::Dbsf => score_fusion(sources, ScoreFusion::dbsf()),
291        };
292
293        let top_fused: Vec<_> = if let Some(score_threshold) = score_threshold {
294            fused
295                .into_iter()
296                .take_while(|point| point.score >= score_threshold)
297                .take(limit)
298                .collect()
299        } else {
300            fused.into_iter().take(limit).collect()
301        };
302
303        Ok(top_fused)
304    }
305
306    pub fn rescore_with_formula(
307        &self,
308        formula: ParsedFormula,
309        prefetches_results: Vec<Vec<ScoredPoint>>,
310        limit: usize,
311        score_threshold: Option<ScoreType>,
312        hw_measurement_acc: HwMeasurementAcc,
313    ) -> OperationResult<Vec<ScoredPoint>> {
314        let ctx = FormulaContext {
315            formula,
316            prefetches_results,
317            limit,
318            score_threshold,
319            is_stopped: Arc::new(AtomicBool::new(false)),
320        };
321
322        let ctx = Arc::new(ctx);
323        let hw_counter = hw_measurement_acc.get_counter_cell();
324
325        let mut rescored_results = Vec::new();
326
327        // Collect the segments first so we don't lock the segment holder during the operations.
328        let segments = self
329            .segments
330            .read()
331            .non_appendable_then_appendable_segments()
332            .collect::<Vec<_>>();
333
334        for segment in segments {
335            let rescored_result = segment
336                .get()
337                .read()
338                .rescore_with_formula(ctx.clone(), &hw_counter)?;
339
340            rescored_results.push(rescored_result);
341        }
342
343        // use aggregator with only one "batch"
344        let mut aggregator = BatchResultAggregator::new(std::iter::once(limit));
345        aggregator.update_point_versions(rescored_results.iter().flatten());
346        aggregator.update_batch_results(0, rescored_results.into_iter().flatten());
347
348        let top =
349            aggregator.into_topk().into_iter().next().ok_or_else(|| {
350                OperationError::service_error("expected first result of aggregator")
351            })?;
352
353        Ok(top)
354    }
355
356    /// Maximal Marginal Relevance rescoring
357    fn mmr_rescore(
358        &self,
359        sources: Vec<Vec<ScoredPoint>>,
360        mmr: MmrInternal,
361        limit: usize,
362        hw_measurement_acc: HwMeasurementAcc,
363    ) -> OperationResult<Vec<ScoredPoint>> {
364        let points_with_vector = self
365            .fill_with_payload_or_vectors(
366                sources,
367                false.into(),
368                WithVector::from(mmr.using.clone()),
369                hw_measurement_acc.clone(),
370            )?
371            .into_iter()
372            .flatten();
373
374        let vector_data_config = self
375            .config
376            .read()
377            .vector_data_config(&mmr.using)
378            .ok_or_else(|| {
379                OperationError::service_error(format!(
380                    "vector data config for vector {} not found",
381                    mmr.using,
382                ))
383            })?;
384
385        // Even if we have fewer points than requested, still calculate MMR.
386        let mut top_mmr = mmr_from_points_with_vector(
387            points_with_vector,
388            mmr,
389            vector_data_config.distance,
390            vector_data_config.multivector_config,
391            limit,
392            hw_measurement_acc,
393        )?;
394
395        // strip mmr vector. We will handle user-requested vectors at root level of request.
396        for point in &mut top_mmr {
397            point.vector = None;
398        }
399
400        Ok(top_mmr)
401    }
402
403    /// This function always filters deferred points.
404    fn fill_with_payload_or_vectors(
405        &self,
406        query_response: ShardQueryResponse,
407        with_payload: WithPayloadInterface,
408        with_vector: WithVector,
409        hw_measurement_acc: HwMeasurementAcc,
410    ) -> OperationResult<ShardQueryResponse> {
411        if !with_payload.is_required() && !with_vector.is_enabled() {
412            return Ok(query_response);
413        }
414
415        // ids to retrieve (deduplication happens in the searcher)
416        let point_ids: Vec<_> = query_response
417            .iter()
418            .flatten()
419            .map(|scored_point| scored_point.id)
420            .collect();
421
422        let records_map = retrieve_blocking(
423            self.segments.clone(),
424            &point_ids,
425            &WithPayload::from(with_payload),
426            &with_vector,
427            DEFAULT_EDGE_TIMEOUT,
428            &AtomicBool::new(false),
429            hw_measurement_acc,
430            DeferredBehavior::Exclude,
431        )?;
432
433        // It might be possible, that we won't find all records,
434        // so we need to re-collect the results
435        let query_response: ShardQueryResponse = query_response
436            .into_iter()
437            .map(|points| {
438                points
439                    .into_iter()
440                    .filter_map(|mut point| {
441                        records_map.get(&point.id).map(|record| {
442                            point.payload.clone_from(&record.payload);
443                            point.vector.clone_from(&record.vector);
444                            point
445                        })
446                    })
447                    .collect()
448            })
449            .collect();
450
451        Ok(query_response)
452    }
453}
454
455fn take_prefetched_source<T: Default>(items: &mut [T], index: usize) -> OperationResult<T> {
456    let source = items.get_mut(index).ok_or_else(|| {
457        OperationError::service_error(format!("prefetched source at index {index} does not exist"))
458    })?;
459
460    Ok(mem::take(source))
461}
462
463/// Extracts point ids from sources, and creates a filter to only include those ids.
464fn filter_by_point_ids(points: &[Vec<ScoredPoint>]) -> Filter {
465    let point_ids: AHashSet<_> = points.iter().flatten().map(|point| point.id).collect();
466
467    // create filter for target point ids
468    Filter::new_must(crate::segment::types::Condition::HasId(HasIdCondition::from(
469        point_ids,
470    )))
471}