1use std::collections::HashSet;
2use std::sync::atomic::AtomicBool;
3
4use crate::common::counter::hardware_accumulator::HwMeasurementAcc;
5use crate::common::types::DeferredBehavior;
6use itertools::Itertools as _;
7use rand::RngExt;
8use rand::distr::weighted::WeightedIndex;
9use rand::rngs::StdRng;
10use crate::segment::common::operation_error::{OperationError, OperationResult};
11use crate::segment::data_types::order_by::{Direction, OrderBy};
12use crate::segment::types::*;
13use crate::shard::query::scroll::{QueryScrollRequestInternal, ScrollOrder};
14use crate::shard::retrieve::record_internal::RecordInternal;
15use crate::shard::retrieve::retrieve_blocking::retrieve_blocking;
16use crate::shard::scroll::ScrollRequestInternal;
17
18use super::EdgeShard;
19use crate::edge::DEFAULT_EDGE_TIMEOUT;
20
21impl EdgeShard {
22 pub fn scroll(
23 &self,
24 request: ScrollRequestInternal,
25 ) -> OperationResult<(Vec<RecordInternal>, Option<PointIdType>)> {
26 let ScrollRequestInternal {
27 offset,
28 limit,
29 filter,
30 with_payload,
31 with_vector,
32 order_by,
33 } = request;
34
35 let limit = limit.unwrap_or(ScrollRequestInternal::default_limit());
36 let with_payload = with_payload.unwrap_or(ScrollRequestInternal::default_with_payload());
37
38 match order_by.map(OrderBy::from) {
39 None => {
40 let limit_plus_one = limit.saturating_add(1);
41 let mut records = self.scroll_by_id(
42 offset,
43 limit_plus_one,
44 &with_payload,
45 &with_vector,
46 filter.as_ref(),
47 HwMeasurementAcc::disposable_edge(),
48 )?;
49 let next_offset = if records.len() > limit {
50 let last_record = records.pop().unwrap();
51 Some(last_record.id)
52 } else {
53 None
54 };
55 Ok((records, next_offset))
56 }
57
58 Some(order_by) => {
59 if offset.is_some() {
60 return Err(OperationError::validation_error(
61 "Offset is not supported when ordering by field",
62 ));
63 }
64 let records = self.scroll_by_field(
65 limit,
66 &with_payload,
67 &with_vector,
68 filter.as_ref(),
69 &order_by,
70 HwMeasurementAcc::disposable_edge(),
71 )?;
72 Ok((records, None))
73 }
74 }
75 }
76
77 pub fn query_scroll(
78 &self,
79 request: &QueryScrollRequestInternal,
80 ) -> OperationResult<Vec<ScoredPoint>> {
81 let QueryScrollRequestInternal {
82 limit,
83 with_vector,
84 filter,
85 scroll_order,
86 with_payload,
87 } = request;
88
89 let records = match scroll_order {
90 ScrollOrder::ById => self.scroll_by_id(
91 None,
92 *limit,
93 with_payload,
94 with_vector,
95 filter.as_ref(),
96 HwMeasurementAcc::disposable_edge(),
97 )?,
98 ScrollOrder::ByField(order_by) => self.scroll_by_field(
99 *limit,
100 with_payload,
101 with_vector,
102 filter.as_ref(),
103 order_by,
104 HwMeasurementAcc::disposable_edge(),
105 )?,
106 ScrollOrder::Random => self.scroll_randomly(
107 *limit,
108 with_payload,
109 with_vector,
110 filter.as_ref(),
111 HwMeasurementAcc::disposable_edge(),
112 )?,
113 };
114
115 let point_results = records
116 .into_iter()
117 .map(|record| ScoredPoint {
118 id: record.id,
119 version: 0,
120 score: 1.0,
121 payload: record.payload,
122 vector: record.vector,
123 shard_key: record.shard_key,
124 order_value: record.order_value,
125 })
126 .collect();
127
128 Ok(point_results)
129 }
130
    /// Scroll records in ascending point-ID order.
    ///
    /// Reads up to `limit` matching point IDs from every segment (starting at
    /// `offset`, if given), merges them into a single sorted, deduplicated list
    /// capped at `limit`, then retrieves the full records for those IDs.
    fn scroll_by_id(
        &self,
        offset: Option<ExtendedPointId>,
        limit: usize,
        with_payload_interface: &WithPayloadInterface,
        with_vector: &WithVector,
        filter: Option<&Filter>,
        hw_measurement_acc: HwMeasurementAcc,
    ) -> OperationResult<Vec<RecordInternal>> {
        let (non_appendable, appendable) = self.segments.read().split_segments();
        let hw_counter = hw_measurement_acc.get_counter_cell();

        // Collect candidate IDs per segment. The same point may appear in more
        // than one segment, hence the sort + dedup before truncating to `limit`.
        // `process_results` short-circuits on the first segment read error.
        let point_ids: Vec<_> = non_appendable
            .into_iter()
            .chain(appendable)
            .map(|segment| {
                segment.get().read().read_filtered(
                    offset,
                    Some(limit),
                    filter,
                    // Cancellation flag: this path is never cancelled.
                    &AtomicBool::new(false),
                    &hw_counter,
                    DeferredBehavior::Exclude,
                )
            })
            .process_results(|iter| iter.flatten().sorted().dedup().take(limit).collect_vec())?;

        let mut points = retrieve_blocking(
            self.segments.clone(),
            &point_ids,
            &WithPayload::from(with_payload_interface),
            with_vector,
            DEFAULT_EDGE_TIMEOUT,
            &AtomicBool::new(false),
            hw_measurement_acc,
            DeferredBehavior::Exclude,
        )?;

        // `retrieve_blocking` yields a map keyed by point ID; restore the
        // sorted ID order established above. IDs missing from the retrieval
        // result (e.g. deleted between the two reads) are silently skipped.
        let ordered_points = point_ids
            .iter()
            .filter_map(|point_id| points.remove(point_id))
            .collect();

        Ok(ordered_points)
    }
176
    /// Scroll records ordered by a payload field, as described by `order_by`.
    ///
    /// Each segment yields up to `limit` `(order value, point ID)` pairs already
    /// sorted in the requested direction; the per-segment streams are k-way
    /// merged, deduplicated, and truncated to `limit` before the full records
    /// are retrieved. The returned records carry their `order_value`.
    fn scroll_by_field(
        &self,
        limit: usize,
        with_payload_interface: &WithPayloadInterface,
        with_vector: &WithVector,
        filter: Option<&Filter>,
        order_by: &OrderBy,
        hw_measurement_acc: HwMeasurementAcc,
    ) -> OperationResult<Vec<RecordInternal>> {
        let (non_appendable, appendable) = self.segments.read().split_segments();
        let hw_counter = hw_measurement_acc.get_counter_cell();

        let read_results: Vec<_> = non_appendable
            .into_iter()
            .chain(appendable)
            .map(|segment| {
                segment.get().read().read_ordered_filtered(
                    Some(limit),
                    filter,
                    order_by,
                    // Cancellation flag: this path is never cancelled.
                    &AtomicBool::new(false),
                    &hw_counter,
                    DeferredBehavior::Exclude,
                )
            })
            .collect::<Result<_, _>>()?;

        // Merge the per-segment sorted streams into one globally ordered stream.
        // `dedup` removes *consecutive* identical pairs; duplicates of the same
        // point from different segments are only dropped when they end up
        // adjacent after the merge.
        let (order_values, point_ids): (Vec<_>, Vec<_>) = read_results
            .into_iter()
            .kmerge_by(|a, b| match order_by.direction() {
                Direction::Asc => a <= b,
                Direction::Desc => a >= b,
            })
            .dedup()
            .take(limit)
            .unzip();

        let points = retrieve_blocking(
            self.segments.clone(),
            &point_ids,
            &WithPayload::from(with_payload_interface),
            with_vector,
            DEFAULT_EDGE_TIMEOUT,
            &AtomicBool::new(false),
            hw_measurement_acc,
            DeferredBehavior::Exclude,
        )?;

        // Re-attach each point's order value and restore the merged ordering.
        // IDs missing from the retrieval result are silently skipped.
        let ordered_points = point_ids
            .iter()
            .zip(order_values)
            .filter_map(|(point_id, value)| {
                let mut record = points.get(point_id).cloned()?;
                record.order_value = Some(value);
                Some(record)
            })
            .collect();

        Ok(ordered_points)
    }
237
    /// Scroll up to `limit` records sampled randomly across all segments.
    ///
    /// Each segment contributes up to `limit` random candidate IDs plus its
    /// available point count; segments are then sampled with probability
    /// proportional to that count, and candidates are drawn from the sampled
    /// segment until `limit` distinct IDs are collected. The result order is
    /// unspecified (IDs pass through a `HashSet`).
    fn scroll_randomly(
        &self,
        limit: usize,
        with_payload_interface: &WithPayloadInterface,
        with_vector: &WithVector,
        filter: Option<&Filter>,
        hw_measurement_acc: HwMeasurementAcc,
    ) -> OperationResult<Vec<RecordInternal>> {
        let (non_appendable, appendable) = self.segments.read().split_segments();
        let hw_counter = hw_measurement_acc.get_counter_cell();

        // Per segment: (available point count, up to `limit` random candidate IDs).
        // The counts later weight the segment sampling below.
        let (point_count, mut point_ids): (Vec<_>, Vec<_>) = non_appendable
            .into_iter()
            .chain(appendable)
            .map(|segment| {
                let segment = segment.get();
                let segment = segment.read();

                let point_count = segment.available_point_count_without_deferred();
                let point_ids = segment.read_random_filtered(
                    limit,
                    filter,
                    // Cancellation flag: this path is never cancelled.
                    &AtomicBool::new(false),
                    &hw_counter,
                )?;

                OperationResult::Ok((point_count, point_ids))
            })
            .process_results(|iter| iter.unzip())?;

        // No available points anywhere: nothing to sample. This also guards
        // `WeightedIndex::new`, which rejects an all-zero weight set.
        if point_count.iter().all(|&count| count == 0) {
            return Ok(Vec::new());
        }

        let distribution = WeightedIndex::new(point_count).map_err(|err| {
            OperationError::service_error(format!(
                "failed to create weighted index for random scroll: {err:?}"
            ))
        })?;

        let mut rng = rand::make_rng::<StdRng>();
        let mut random_point_ids = HashSet::with_capacity(limit);

        // Draw segments proportionally to their point count and pop one
        // candidate per draw. The set deduplicates repeated candidates.
        while random_point_ids.len() < limit {
            let segment_idx = rng.sample(&distribution);
            let segment_point_ids = &mut point_ids[segment_idx];

            if let Some(point) = segment_point_ids.pop() {
                random_point_ids.insert(point);
            } else {
                // Sampled segment ran out of candidates: stop sampling and
                // fall through to the sequential top-up below.
                break;
            }
        }

        // Top up from the remaining candidates (in segment order) if weighted
        // sampling stopped early; trades randomness for filling the page.
        if random_point_ids.len() < limit {
            for point_id in point_ids.into_iter().flatten() {
                random_point_ids.insert(point_id);
                if random_point_ids.len() >= limit {
                    break;
                }
            }
        }

        let random_point_ids: Vec<_> = random_point_ids.into_iter().collect();

        let random_points = retrieve_blocking(
            self.segments.clone(),
            &random_point_ids,
            &WithPayload::from(with_payload_interface),
            with_vector,
            DEFAULT_EDGE_TIMEOUT,
            &AtomicBool::new(false),
            hw_measurement_acc,
            DeferredBehavior::Exclude,
        )?
        .into_values()
        .collect();

        Ok(random_points)
    }
334}