//! qdrant_edge/edge/scroll.rs — scroll operations for `EdgeShard`
//! (by point ID, ordered by payload field, or random sampling).

1use std::collections::HashSet;
2use std::sync::atomic::AtomicBool;
3
4use crate::common::counter::hardware_accumulator::HwMeasurementAcc;
5use crate::common::types::DeferredBehavior;
6use itertools::Itertools as _;
7use rand::RngExt;
8use rand::distr::weighted::WeightedIndex;
9use rand::rngs::StdRng;
10use crate::segment::common::operation_error::{OperationError, OperationResult};
11use crate::segment::data_types::order_by::{Direction, OrderBy};
12use crate::segment::types::*;
13use crate::shard::query::scroll::{QueryScrollRequestInternal, ScrollOrder};
14use crate::shard::retrieve::record_internal::RecordInternal;
15use crate::shard::retrieve::retrieve_blocking::retrieve_blocking;
16use crate::shard::scroll::ScrollRequestInternal;
17
18use super::EdgeShard;
19use crate::edge::DEFAULT_EDGE_TIMEOUT;
20
21impl EdgeShard {
22    pub fn scroll(
23        &self,
24        request: ScrollRequestInternal,
25    ) -> OperationResult<(Vec<RecordInternal>, Option<PointIdType>)> {
26        let ScrollRequestInternal {
27            offset,
28            limit,
29            filter,
30            with_payload,
31            with_vector,
32            order_by,
33        } = request;
34
35        let limit = limit.unwrap_or(ScrollRequestInternal::default_limit());
36        let with_payload = with_payload.unwrap_or(ScrollRequestInternal::default_with_payload());
37
38        match order_by.map(OrderBy::from) {
39            None => {
40                let limit_plus_one = limit.saturating_add(1);
41                let mut records = self.scroll_by_id(
42                    offset,
43                    limit_plus_one,
44                    &with_payload,
45                    &with_vector,
46                    filter.as_ref(),
47                    HwMeasurementAcc::disposable_edge(),
48                )?;
49                let next_offset = if records.len() > limit {
50                    let last_record = records.pop().unwrap();
51                    Some(last_record.id)
52                } else {
53                    None
54                };
55                Ok((records, next_offset))
56            }
57
58            Some(order_by) => {
59                if offset.is_some() {
60                    return Err(OperationError::validation_error(
61                        "Offset is not supported when ordering by field",
62                    ));
63                }
64                let records = self.scroll_by_field(
65                    limit,
66                    &with_payload,
67                    &with_vector,
68                    filter.as_ref(),
69                    &order_by,
70                    HwMeasurementAcc::disposable_edge(),
71                )?;
72                Ok((records, None))
73            }
74        }
75    }
76
77    pub fn query_scroll(
78        &self,
79        request: &QueryScrollRequestInternal,
80    ) -> OperationResult<Vec<ScoredPoint>> {
81        let QueryScrollRequestInternal {
82            limit,
83            with_vector,
84            filter,
85            scroll_order,
86            with_payload,
87        } = request;
88
89        let records = match scroll_order {
90            ScrollOrder::ById => self.scroll_by_id(
91                None,
92                *limit,
93                with_payload,
94                with_vector,
95                filter.as_ref(),
96                HwMeasurementAcc::disposable_edge(),
97            )?,
98            ScrollOrder::ByField(order_by) => self.scroll_by_field(
99                *limit,
100                with_payload,
101                with_vector,
102                filter.as_ref(),
103                order_by,
104                HwMeasurementAcc::disposable_edge(),
105            )?,
106            ScrollOrder::Random => self.scroll_randomly(
107                *limit,
108                with_payload,
109                with_vector,
110                filter.as_ref(),
111                HwMeasurementAcc::disposable_edge(),
112            )?,
113        };
114
115        let point_results = records
116            .into_iter()
117            .map(|record| ScoredPoint {
118                id: record.id,
119                version: 0,
120                score: 1.0,
121                payload: record.payload,
122                vector: record.vector,
123                shard_key: record.shard_key,
124                order_value: record.order_value,
125            })
126            .collect();
127
128        Ok(point_results)
129    }
130
131    fn scroll_by_id(
132        &self,
133        offset: Option<ExtendedPointId>,
134        limit: usize,
135        with_payload_interface: &WithPayloadInterface,
136        with_vector: &WithVector,
137        filter: Option<&Filter>,
138        hw_measurement_acc: HwMeasurementAcc,
139    ) -> OperationResult<Vec<RecordInternal>> {
140        let (non_appendable, appendable) = self.segments.read().split_segments();
141        let hw_counter = hw_measurement_acc.get_counter_cell();
142
143        let point_ids: Vec<_> = non_appendable
144            .into_iter()
145            .chain(appendable)
146            .map(|segment| {
147                segment.get().read().read_filtered(
148                    offset,
149                    Some(limit),
150                    filter,
151                    &AtomicBool::new(false),
152                    &hw_counter,
153                    DeferredBehavior::Exclude,
154                )
155            })
156            .process_results(|iter| iter.flatten().sorted().dedup().take(limit).collect_vec())?;
157
158        let mut points = retrieve_blocking(
159            self.segments.clone(),
160            &point_ids,
161            &WithPayload::from(with_payload_interface),
162            with_vector,
163            DEFAULT_EDGE_TIMEOUT,
164            &AtomicBool::new(false),
165            hw_measurement_acc,
166            DeferredBehavior::Exclude,
167        )?;
168
169        let ordered_points = point_ids
170            .iter()
171            .filter_map(|point_id| points.remove(point_id))
172            .collect();
173
174        Ok(ordered_points)
175    }
176
177    fn scroll_by_field(
178        &self,
179        limit: usize,
180        with_payload_interface: &WithPayloadInterface,
181        with_vector: &WithVector,
182        filter: Option<&Filter>,
183        order_by: &OrderBy,
184        hw_measurement_acc: HwMeasurementAcc,
185    ) -> OperationResult<Vec<RecordInternal>> {
186        let (non_appendable, appendable) = self.segments.read().split_segments();
187        let hw_counter = hw_measurement_acc.get_counter_cell();
188
189        let read_results: Vec<_> = non_appendable
190            .into_iter()
191            .chain(appendable)
192            .map(|segment| {
193                segment.get().read().read_ordered_filtered(
194                    Some(limit),
195                    filter,
196                    order_by,
197                    &AtomicBool::new(false),
198                    &hw_counter,
199                    DeferredBehavior::Exclude,
200                )
201            })
202            .collect::<Result<_, _>>()?;
203
204        let (order_values, point_ids): (Vec<_>, Vec<_>) = read_results
205            .into_iter()
206            .kmerge_by(|a, b| match order_by.direction() {
207                Direction::Asc => a <= b,
208                Direction::Desc => a >= b,
209            })
210            .dedup()
211            .take(limit)
212            .unzip();
213
214        let points = retrieve_blocking(
215            self.segments.clone(),
216            &point_ids,
217            &WithPayload::from(with_payload_interface),
218            with_vector,
219            DEFAULT_EDGE_TIMEOUT,
220            &AtomicBool::new(false),
221            hw_measurement_acc,
222            DeferredBehavior::Exclude,
223        )?;
224
225        let ordered_points = point_ids
226            .iter()
227            .zip(order_values)
228            .filter_map(|(point_id, value)| {
229                let mut record = points.get(point_id).cloned()?;
230                record.order_value = Some(value);
231                Some(record)
232            })
233            .collect();
234
235        Ok(ordered_points)
236    }
237
    /// Collect up to `limit` records sampled randomly across all segments.
    ///
    /// Sampling is weighted by each segment's available point count, so larger
    /// segments contribute proportionally more points.
    /// NOTE(review): weights come from the segment's total available point count,
    /// not the filtered count — presumably an accepted approximation; confirm.
    fn scroll_randomly(
        &self,
        limit: usize,
        with_payload_interface: &WithPayloadInterface,
        with_vector: &WithVector,
        filter: Option<&Filter>,
        hw_measurement_acc: HwMeasurementAcc,
    ) -> OperationResult<Vec<RecordInternal>> {
        let (non_appendable, appendable) = self.segments.read().split_segments();
        let hw_counter = hw_measurement_acc.get_counter_cell();

        // For each segment, gather its sampling weight (available point count)
        // and up to `limit` random candidate IDs matching the filter.
        let (point_count, mut point_ids): (Vec<_>, Vec<_>) = non_appendable
            .into_iter()
            .chain(appendable)
            .map(|segment| {
                let segment = segment.get();
                let segment = segment.read();

                let point_count = segment.available_point_count_without_deferred();
                let point_ids = segment.read_random_filtered(
                    limit,
                    filter,
                    &AtomicBool::new(false),
                    &hw_counter,
                )?;

                OperationResult::Ok((point_count, point_ids))
            })
            .process_results(|iter| iter.unzip())?;

        // Shortcut if all segments are empty
        if point_count.iter().all(|&count| count == 0) {
            return Ok(Vec::new());
        }

        // Select points in a weighted fashion from each segment, depending on how many points each segment has.
        let distribution = WeightedIndex::new(point_count).map_err(|err| {
            OperationError::service_error(format!(
                "failed to create weighted index for random scroll: {err:?}"
            ))
        })?;

        let mut rng = rand::make_rng::<StdRng>();
        // HashSet deduplicates IDs sampled from several segments.
        let mut random_point_ids = HashSet::with_capacity(limit);

        // Randomly sample points in two stages
        //
        // 1. This loop iterates <= LIMIT times, and either breaks early if we
        // have enough points, or if some of the segments are exhausted.
        //
        // 2. If the segments are exhausted, we will fill up the rest of the
        // points from other segments. In total, the complexity is guaranteed to
        // be O(limit).
        while random_point_ids.len() < limit {
            let segment_idx = rng.sample(&distribution);
            let segment_point_ids = &mut point_ids[segment_idx];

            if let Some(point) = segment_point_ids.pop() {
                random_point_ids.insert(point);
            } else {
                // It seems that some segments are empty early,
                // so distribution does not make sense anymore.
                // This is only possible if segments size < limit.
                break;
            }
        }

        // If we still need more points, we will get them from the rest of the segments.
        // This is a rare case, as it seems we don't have enough points in individual segments.
        // Therefore, we can ignore "proper" distribution, as it won't be accurate anyway.
        if random_point_ids.len() < limit {
            for point_id in point_ids.into_iter().flatten() {
                random_point_ids.insert(point_id);
                if random_point_ids.len() >= limit {
                    break;
                }
            }
        }

        let random_point_ids: Vec<_> = random_point_ids.into_iter().collect();

        // Result order is unspecified (map iteration order) — acceptable for a
        // random sample.
        let random_points = retrieve_blocking(
            self.segments.clone(),
            &random_point_ids,
            &WithPayload::from(with_payload_interface),
            with_vector,
            DEFAULT_EDGE_TIMEOUT,
            &AtomicBool::new(false),
            hw_measurement_acc,
            DeferredBehavior::Exclude,
        )?
        .into_values()
        .collect();

        Ok(random_points)
    }
334}