// Source: qdrant_edge/segment/vector_storage/query/discover_query.rs

1use std::hash::Hash;
2use std::iter;
3
4use crate::common::math::scaled_fast_sigmoid;
5use crate::common::types::ScoreType;
6use itertools::Itertools;
7use serde::Serialize;
8
9use super::context_query::ContextPair;
10use super::{Query, TransformInto};
11use crate::segment::common::operation_error::OperationResult;
12use crate::segment::data_types::vectors::{QueryVector, VectorInternal};
13
/// Integer verdict of how a point relates to the context pairs of a query.
///
/// Each `ContextPair` contributes `-1`, `0` or `+1`, and a `DiscoverQuery`
/// adds those contributions together.
type RankType = i32;

16impl<T> ContextPair<T> {
17    /// Calculates on which side of the space the point is, with respect to this pair
18    fn rank_by(&self, similarity: impl Fn(&T) -> ScoreType) -> RankType {
19        let positive_similarity = similarity(&self.positive);
20        let negative_similarity = similarity(&self.negative);
21
22        // if closer to positive, return 1, else -1
23        positive_similarity.total_cmp(&negative_similarity) as RankType
24    }
25}
26
27#[derive(Debug, Clone, PartialEq, Serialize, Hash)]
28pub struct DiscoverQuery<T> {
29    pub target: T,
30    pub pairs: Vec<ContextPair<T>>,
31}
32
33impl<T> DiscoverQuery<T> {
34    pub fn new(target: T, pairs: Vec<ContextPair<T>>) -> Self {
35        Self { target, pairs }
36    }
37
38    pub fn flat_iter(&self) -> impl Iterator<Item = &T> {
39        let pairs_iter = self.pairs.iter().flat_map(|pair| pair.iter());
40
41        iter::once(&self.target).chain(pairs_iter)
42    }
43
44    fn rank_by(&self, similarity: impl Fn(&T) -> ScoreType) -> RankType {
45        self.pairs
46            .iter()
47            .map(|pair| pair.rank_by(&similarity))
48            // get overall rank
49            .sum()
50    }
51}
52
53impl<T, U> TransformInto<DiscoverQuery<U>, T, U> for DiscoverQuery<T> {
54    fn transform<F>(self, mut f: F) -> OperationResult<DiscoverQuery<U>>
55    where
56        F: FnMut(T) -> OperationResult<U>,
57    {
58        Ok(DiscoverQuery::new(
59            f(self.target)?,
60            self.pairs
61                .into_iter()
62                .map(|pair| pair.transform(&mut f))
63                .try_collect()?,
64        ))
65    }
66}
67
68impl<T> Query<T> for DiscoverQuery<T> {
69    fn score_by(&self, similarity: impl Fn(&T) -> ScoreType) -> ScoreType {
70        let rank = self.rank_by(&similarity);
71
72        let target_similarity = similarity(&self.target);
73        let sigmoid_similarity = scaled_fast_sigmoid(target_similarity);
74
75        rank as ScoreType + sigmoid_similarity
76    }
77}
78
79impl From<DiscoverQuery<VectorInternal>> for QueryVector {
80    fn from(query: DiscoverQuery<VectorInternal>) -> Self {
81        QueryVector::Discover(query)
82    }
83}
84
#[cfg(test)]
mod test {
    use std::cmp::Ordering;

    use crate::common::types::ScoreType;
    use itertools::Itertools;
    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// Identity "similarity": each integer "vector" is its own score.
    fn dummy_similarity(x: &isize) -> ScoreType {
        *x as ScoreType
    }

    /// Considers each "vector" as the actual score from the similarity function by
    /// using a dummy identity function.
    #[rstest]
    #[case::no_pairs(vec![], 0)]
    #[case::closer_to_positive(vec![(10, 4)], 1)]
    #[case::closer_to_negative(vec![(4, 10)], -1)]
    #[case::equal_scores(vec![(11, 11)], 0)]
    #[case::neutral_zone(vec![(10, 4), (4, 10)], 0)]
    #[case::best_zone(vec![(10, 4), (4, 2)], 2)]
    #[case::worst_zone(vec![(4, 10), (2, 4)], -2)]
    #[case::many_pairs(vec![(1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (0, 4)], 4)]
    fn context_ranking(#[case] pairs: Vec<(isize, isize)>, #[case] expected: RankType) {
        let pairs = pairs.into_iter().map(ContextPair::from).collect();

        let target = 42;

        let query = DiscoverQuery::new(target, pairs);

        let rank = query.rank_by(dummy_similarity);

        assert_eq!(
            rank, expected,
            "Ranking is incorrect, expected {expected}, but got {rank}"
        );
    }

    /// Compares the score of a query against a fixed score
    #[rstest]
    #[case::no_pairs(1, vec![], Ordering::Less)]
    #[case::just_above(1, vec![(1, 0), (1, 0)], Ordering::Greater)]
    #[case::just_below(-1, vec![(1, 0), (1, 0)], Ordering::Less)]
    #[case::bad_target_good_context(-1000, vec![(1, 0), (1, 0), (1, 0)], Ordering::Greater)]
    #[case::good_target_bad_context(1000, vec![(1, 0), (0, 1)], Ordering::Less)]
    fn score_better(
        #[case] target: isize,
        #[case] pairs: Vec<(isize, isize)>,
        #[case] expected: Ordering,
    ) {
        let fixed_score: f32 = 2.5;

        let pairs = pairs.into_iter().map(ContextPair::from).collect();

        let query = DiscoverQuery::new(target, pairs);

        let score = query.score_by(dummy_similarity);

        assert_eq!(
            score.total_cmp(&fixed_score),
            expected,
            "Comparison is incorrect, expected {expected:?} for {score} against {fixed_score}"
        );
    }

    proptest! {
        // Changing only the context pairs may shift the integer part of the score,
        // but must leave the fractional (target-driven) part unchanged.
        #[test]
        fn same_target_only_changes_rank(
            target in -1000f32..1000f32,
            pairs1 in prop::collection::vec((0f32..1000f32, 0f32..1000f32), 0..10),
            pairs2 in prop::collection::vec((0f32..1000f32, 0f32..1000f32), 0..10),
        ) {
            // Values are already `ScoreType`, so the identity needs no cast.
            let identity_similarity = |x: &ScoreType| *x;

            let pairs1 = pairs1.into_iter().map(ContextPair::from).collect();
            let query1 = DiscoverQuery::new(target, pairs1);
            let score1 = query1.score_by(identity_similarity);

            let pairs2 = pairs2.into_iter().map(ContextPair::from).collect();
            let query2 = DiscoverQuery::new(target, pairs2);
            let score2 = query2.score_by(identity_similarity);

            let target_part1 = score1 - score1.floor();
            let target_part2 = score2 - score2.floor();

            // `prop_assert!` reports a failure through proptest instead of panicking.
            // Explicit format args are used because proptest's assertion macros may
            // assemble the format string with `concat!`.
            prop_assert!(
                (target_part1 - target_part2).abs() <= 1.0e-6,
                "Target part of score is not similar, score1: {}, score2: {}",
                score1,
                score2
            );
        }

        // Changing only the target may shift the fractional part of the score,
        // but must leave the integer (context rank) part unchanged.
        #[test]
        fn same_context_only_changes_target(
            target1 in -1000f32..1000f32,
            target2 in -1000f32..1000f32,
            pairs in prop::collection::vec((0f32..1000f32, 0f32..1000f32), 0..10),
        ) {
            let identity_similarity = |x: &ScoreType| *x;

            let pairs = pairs.into_iter().map(ContextPair::from).collect_vec();
            let query1 = DiscoverQuery::new(target1, pairs.clone());
            let score1 = query1.score_by(identity_similarity);

            let query2 = DiscoverQuery::new(target2, pairs);
            let score2 = query2.score_by(identity_similarity);

            let context_part1 = score1.floor();
            let context_part2 = score2.floor();

            prop_assert_eq!(
                context_part1,
                context_part2,
                "Context part of score isn't equal, score1: {}, score2: {}",
                score1,
                score2
            );
        }
    }
}