qdrant_edge/segment/vector_storage/query/
discover_query.rs1use std::hash::Hash;
2use std::iter;
3
4use crate::common::math::scaled_fast_sigmoid;
5use crate::common::types::ScoreType;
6use itertools::Itertools;
7use serde::Serialize;
8
9use super::context_query::ContextPair;
10use super::{Query, TransformInto};
11use crate::segment::common::operation_error::OperationResult;
12use crate::segment::data_types::vectors::{QueryVector, VectorInternal};
13
14type RankType = i32;
15
16impl<T> ContextPair<T> {
17 fn rank_by(&self, similarity: impl Fn(&T) -> ScoreType) -> RankType {
19 let positive_similarity = similarity(&self.positive);
20 let negative_similarity = similarity(&self.negative);
21
22 positive_similarity.total_cmp(&negative_similarity) as RankType
24 }
25}
26
27#[derive(Debug, Clone, PartialEq, Serialize, Hash)]
28pub struct DiscoverQuery<T> {
29 pub target: T,
30 pub pairs: Vec<ContextPair<T>>,
31}
32
33impl<T> DiscoverQuery<T> {
34 pub fn new(target: T, pairs: Vec<ContextPair<T>>) -> Self {
35 Self { target, pairs }
36 }
37
38 pub fn flat_iter(&self) -> impl Iterator<Item = &T> {
39 let pairs_iter = self.pairs.iter().flat_map(|pair| pair.iter());
40
41 iter::once(&self.target).chain(pairs_iter)
42 }
43
44 fn rank_by(&self, similarity: impl Fn(&T) -> ScoreType) -> RankType {
45 self.pairs
46 .iter()
47 .map(|pair| pair.rank_by(&similarity))
48 .sum()
50 }
51}
52
53impl<T, U> TransformInto<DiscoverQuery<U>, T, U> for DiscoverQuery<T> {
54 fn transform<F>(self, mut f: F) -> OperationResult<DiscoverQuery<U>>
55 where
56 F: FnMut(T) -> OperationResult<U>,
57 {
58 Ok(DiscoverQuery::new(
59 f(self.target)?,
60 self.pairs
61 .into_iter()
62 .map(|pair| pair.transform(&mut f))
63 .try_collect()?,
64 ))
65 }
66}
67
68impl<T> Query<T> for DiscoverQuery<T> {
69 fn score_by(&self, similarity: impl Fn(&T) -> ScoreType) -> ScoreType {
70 let rank = self.rank_by(&similarity);
71
72 let target_similarity = similarity(&self.target);
73 let sigmoid_similarity = scaled_fast_sigmoid(target_similarity);
74
75 rank as ScoreType + sigmoid_similarity
76 }
77}
78
79impl From<DiscoverQuery<VectorInternal>> for QueryVector {
80 fn from(query: DiscoverQuery<VectorInternal>) -> Self {
81 QueryVector::Discover(query)
82 }
83}
84
#[cfg(test)]
mod test {
    use std::cmp::Ordering;

    use crate::common::types::ScoreType;
    use itertools::Itertools;
    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// Uses each integer example directly as its similarity score.
    fn identity_similarity(value: &isize) -> ScoreType {
        *value as ScoreType
    }

    #[rstest]
    #[case::no_pairs(vec![], 0)]
    #[case::closer_to_positive(vec![(10, 4)], 1)]
    #[case::closer_to_negative(vec![(4, 10)], -1)]
    #[case::equal_scores(vec![(11, 11)], 0)]
    #[case::neutral_zone(vec![(10, 4), (4, 10)], 0)]
    #[case::best_zone(vec![(10, 4), (4, 2)], 2)]
    #[case::worst_zone(vec![(4, 10), (2, 4)], -2)]
    #[case::many_pairs(vec![(1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (0, 4)], 4)]
    fn context_ranking(#[case] pairs: Vec<(isize, isize)>, #[case] expected: RankType) {
        // The target value is irrelevant for ranking; only the pairs matter.
        let query = DiscoverQuery::new(42, pairs.into_iter().map(ContextPair::from).collect());

        let rank = query.rank_by(identity_similarity);

        assert_eq!(
            rank, expected,
            "Ranking is incorrect, expected {expected}, but got {rank}"
        );
    }

    #[rstest]
    #[case::no_pairs(1, vec![], Ordering::Less)]
    #[case::just_above(1, vec![(1,0),(1,0)], Ordering::Greater)]
    #[case::just_below(-1, vec![(1,0),(1,0)], Ordering::Less)]
    #[case::bad_target_good_context(-1000, vec![(1,0),(1,0),(1, 0)], Ordering::Greater)]
    #[case::good_target_bad_context(1000, vec![(1,0),(0,1)], Ordering::Less)]
    fn score_better(
        #[case] target: isize,
        #[case] pairs: Vec<(isize, isize)>,
        #[case] expected: Ordering,
    ) {
        let fixed_score: ScoreType = 2.5;

        let context = pairs.into_iter().map(ContextPair::from).collect();
        let score = DiscoverQuery::new(target, context).score_by(identity_similarity);

        assert_eq!(
            score.total_cmp(&fixed_score),
            expected,
            "Comparison is incorrect, expected {expected:?} for {score} against {fixed_score}"
        );
    }

    proptest! {
        #[test]
        fn same_target_only_changes_rank(
            target in -1000f32..1000f32,
            pairs1 in prop::collection::vec((0f32..1000f32, 0.0f32..1000f32), 0..10),
            pairs2 in prop::collection::vec((0f32..1000f32, 0.0f32..1000f32), 0..10),
        ) {
            let identity = |x: &ScoreType| *x;

            let score1 = DiscoverQuery::new(
                target,
                pairs1.into_iter().map(ContextPair::from).collect(),
            )
            .score_by(identity);
            let score2 = DiscoverQuery::new(
                target,
                pairs2.into_iter().map(ContextPair::from).collect(),
            )
            .score_by(identity);

            // The fractional part is the target contribution; it must not depend on the pairs.
            let fraction1 = score1 - score1.floor();
            let fraction2 = score2 - score2.floor();

            assert!(
                (fraction1 - fraction2).abs() <= 1.0e-6,
                "Target part of score is not similar, score1: {score1}, score2: {score2}"
            );
        }

        #[test]
        fn same_context_only_changes_target(
            target1 in -1000f32..1000f32,
            target2 in -1000f32..1000f32,
            pairs in prop::collection::vec((0f32..1000f32, 0.0f32..1000f32), 0..10),
        ) {
            let identity = |x: &ScoreType| *x;

            let context = pairs.into_iter().map(ContextPair::from).collect_vec();
            let score1 = DiscoverQuery::new(target1, context.clone()).score_by(identity);
            let score2 = DiscoverQuery::new(target2, context).score_by(identity);

            // The integer part is the context rank; it must not depend on the target.
            assert_eq!(
                score1.floor(),
                score2.floor(),
                "Context part of score isn't equal, score1: {score1}, score2: {score2}"
            );
        }
    }
}