Skip to main content

qdrant_edge/segment/vector_storage/query/
context_query.rs

1use std::hash::Hash;
2use std::iter::{self, Chain, Once};
3
4use crate::common::math::fast_sigmoid;
5use crate::common::types::ScoreType;
6use itertools::Itertools;
7use serde::Serialize;
8
9use super::{Query, TransformInto};
10use crate::segment::common::operation_error::OperationResult;
11use crate::segment::data_types::vectors::{QueryVector, VectorInternal};
12
13#[derive(Debug, Clone, PartialEq, Serialize, Hash)]
14pub struct ContextPair<T> {
15    pub positive: T,
16    pub negative: T,
17}
18
19impl<T> ContextPair<T> {
20    pub fn iter(&self) -> impl Iterator<Item = &T> {
21        iter::once(&self.positive).chain(iter::once(&self.negative))
22    }
23
24    pub fn transform<F, U>(self, mut f: F) -> OperationResult<ContextPair<U>>
25    where
26        F: FnMut(T) -> OperationResult<U>,
27    {
28        Ok(ContextPair {
29            positive: f(self.positive)?,
30            negative: f(self.negative)?,
31        })
32    }
33
34    /// In the first stage of discovery search, the objective is to get the best entry point
35    /// for the search. This is done by using a smooth loss function instead of hard ranking
36    /// to approach the best zone, once the best zone is reached, score will be same for all
37    /// points inside that zone.
38    /// e.g.:
39    /// ```text
40    ///                   │
41    ///                   │
42    ///                   │    +0
43    ///                   │             +0
44    ///                   │
45    ///         n         │         p
46    ///                   │
47    ///   ─►          ─►  │
48    ///  -0.4        -0.1 │   +0
49    ///                   │
50    /// ```
51    /// Simple 2D model:
52    /// <https://www.desmos.com/calculator/lbxycyh2hs>
53    pub fn loss_by(&self, similarity: impl Fn(&T) -> ScoreType) -> ScoreType {
54        const MARGIN: ScoreType = ScoreType::EPSILON;
55
56        let positive = similarity(&self.positive);
57        let negative = similarity(&self.negative);
58
59        let difference = positive - negative - MARGIN;
60
61        fast_sigmoid(ScoreType::min(difference, 0.0))
62    }
63}
64
65impl<T> IntoIterator for ContextPair<T> {
66    type Item = T;
67
68    type IntoIter = Chain<Once<T>, Once<T>>;
69
70    fn into_iter(self) -> Self::IntoIter {
71        iter::once(self.positive).chain(iter::once(self.negative))
72    }
73}
74
#[cfg(test)]
impl<T> From<(T, T)> for ContextPair<T> {
    /// Test convenience: builds a pair from a `(positive, negative)` tuple.
    fn from((positive, negative): (T, T)) -> Self {
        ContextPair { positive, negative }
    }
}
84
85#[derive(Debug, Clone, PartialEq, Serialize, Hash)]
86pub struct ContextQuery<T> {
87    pub pairs: Vec<ContextPair<T>>,
88}
89
90impl<T> ContextQuery<T> {
91    pub fn new(pairs: Vec<ContextPair<T>>) -> Self {
92        Self { pairs }
93    }
94
95    pub fn flat_iter(&self) -> impl Iterator<Item = &T> {
96        self.pairs.iter().flat_map(|pair| pair.iter())
97    }
98}
99
100impl<T, U> TransformInto<ContextQuery<U>, T, U> for ContextQuery<T> {
101    fn transform<F>(self, mut f: F) -> OperationResult<ContextQuery<U>>
102    where
103        F: FnMut(T) -> OperationResult<U>,
104    {
105        Ok(ContextQuery::new(
106            self.pairs
107                .into_iter()
108                .map(|pair| pair.transform(&mut f))
109                .try_collect()?,
110        ))
111    }
112}
113
114impl<T> Query<T> for ContextQuery<T> {
115    fn score_by(&self, similarity: impl Fn(&T) -> ScoreType) -> ScoreType {
116        self.pairs
117            .iter()
118            .map(|pair| pair.loss_by(&similarity))
119            .sum()
120    }
121}
122
123impl<T> From<Vec<ContextPair<T>>> for ContextQuery<T> {
124    fn from(pairs: Vec<ContextPair<T>>) -> Self {
125        ContextQuery::new(pairs)
126    }
127}
128
129impl From<ContextQuery<VectorInternal>> for QueryVector {
130    fn from(query: ContextQuery<VectorInternal>) -> Self {
131        QueryVector::Context(query)
132    }
133}
134
#[cfg(test)]
mod test {
    use crate::common::types::ScoreType;
    use proptest::prelude::*;

    use super::*;

    /// Identity-like similarity: each example's value is used as its score.
    fn dummy_similarity(x: &f32) -> ScoreType {
        *x as ScoreType
    }

    /// Possible similarities, sampled directly as `f32` (no f64 round-trip).
    fn sim() -> impl Strategy<Value = f32> {
        -100.0f32..=100.0f32
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]

        /// Checks that the loss of a single pair lies in `(-1, 0]`.
        #[test]
        fn loss_is_not_more_than_1_per_pair((p, n) in (sim(), sim())) {
            let query = ContextQuery::new(vec![ContextPair::from((p, n))]);

            let score = query.score_by(dummy_similarity);
            prop_assert!(score <= 0.0, "similarity: {}", score);
            prop_assert!(score > -1.0, "similarity: {}", score);
        }

        /// Each pair contributes a loss in `(-1, 0]`, so a query with `k >= 1`
        /// pairs must score within `(-k, 0]`.
        #[test]
        fn loss_is_bounded_by_number_of_pairs(
            pairs in prop::collection::vec((sim(), sim()), 1..=8)
        ) {
            let max_penalty = pairs.len() as ScoreType;
            let query = ContextQuery::new(pairs.into_iter().map(ContextPair::from).collect());

            let score = query.score_by(dummy_similarity);
            prop_assert!(score <= 0.0, "similarity: {}", score);
            prop_assert!(score > -max_penalty, "similarity: {}", score);
        }
    }
}