qdrant_edge/segment/vector_storage/query/
context_query.rs1use std::hash::Hash;
2use std::iter::{self, Chain, Once};
3
4use crate::common::math::fast_sigmoid;
5use crate::common::types::ScoreType;
6use itertools::Itertools;
7use serde::Serialize;
8
9use super::{Query, TransformInto};
10use crate::segment::common::operation_error::OperationResult;
11use crate::segment::data_types::vectors::{QueryVector, VectorInternal};
12
13#[derive(Debug, Clone, PartialEq, Serialize, Hash)]
14pub struct ContextPair<T> {
15 pub positive: T,
16 pub negative: T,
17}
18
19impl<T> ContextPair<T> {
20 pub fn iter(&self) -> impl Iterator<Item = &T> {
21 iter::once(&self.positive).chain(iter::once(&self.negative))
22 }
23
24 pub fn transform<F, U>(self, mut f: F) -> OperationResult<ContextPair<U>>
25 where
26 F: FnMut(T) -> OperationResult<U>,
27 {
28 Ok(ContextPair {
29 positive: f(self.positive)?,
30 negative: f(self.negative)?,
31 })
32 }
33
34 pub fn loss_by(&self, similarity: impl Fn(&T) -> ScoreType) -> ScoreType {
54 const MARGIN: ScoreType = ScoreType::EPSILON;
55
56 let positive = similarity(&self.positive);
57 let negative = similarity(&self.negative);
58
59 let difference = positive - negative - MARGIN;
60
61 fast_sigmoid(ScoreType::min(difference, 0.0))
62 }
63}
64
65impl<T> IntoIterator for ContextPair<T> {
66 type Item = T;
67
68 type IntoIter = Chain<Once<T>, Once<T>>;
69
70 fn into_iter(self) -> Self::IntoIter {
71 iter::once(self.positive).chain(iter::once(self.negative))
72 }
73}
74
/// Test-only shorthand: builds a pair from a `(positive, negative)` tuple.
#[cfg(test)]
impl<T> From<(T, T)> for ContextPair<T> {
    fn from((positive, negative): (T, T)) -> Self {
        Self { positive, negative }
    }
}
84
85#[derive(Debug, Clone, PartialEq, Serialize, Hash)]
86pub struct ContextQuery<T> {
87 pub pairs: Vec<ContextPair<T>>,
88}
89
90impl<T> ContextQuery<T> {
91 pub fn new(pairs: Vec<ContextPair<T>>) -> Self {
92 Self { pairs }
93 }
94
95 pub fn flat_iter(&self) -> impl Iterator<Item = &T> {
96 self.pairs.iter().flat_map(|pair| pair.iter())
97 }
98}
99
100impl<T, U> TransformInto<ContextQuery<U>, T, U> for ContextQuery<T> {
101 fn transform<F>(self, mut f: F) -> OperationResult<ContextQuery<U>>
102 where
103 F: FnMut(T) -> OperationResult<U>,
104 {
105 Ok(ContextQuery::new(
106 self.pairs
107 .into_iter()
108 .map(|pair| pair.transform(&mut f))
109 .try_collect()?,
110 ))
111 }
112}
113
114impl<T> Query<T> for ContextQuery<T> {
115 fn score_by(&self, similarity: impl Fn(&T) -> ScoreType) -> ScoreType {
116 self.pairs
117 .iter()
118 .map(|pair| pair.loss_by(&similarity))
119 .sum()
120 }
121}
122
123impl<T> From<Vec<ContextPair<T>>> for ContextQuery<T> {
124 fn from(pairs: Vec<ContextPair<T>>) -> Self {
125 ContextQuery::new(pairs)
126 }
127}
128
129impl From<ContextQuery<VectorInternal>> for QueryVector {
130 fn from(query: ContextQuery<VectorInternal>) -> Self {
131 QueryVector::Context(query)
132 }
133}
134
#[cfg(test)]
mod test {
    use crate::common::types::ScoreType;
    use proptest::prelude::*;

    use super::*;

    /// Identity "similarity": each example's value is its own score.
    fn dummy_similarity(x: &f32) -> ScoreType {
        *x as ScoreType
    }

    /// Strategy producing similarity values in `[-100, 100]`.
    fn sim() -> impl Strategy<Value = f32> {
        -100.0f32..=100.0f32
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]

        /// The loss of a single pair always lies in `(-1, 0]`.
        #[test]
        fn loss_is_not_more_than_1_per_pair((p, n) in (sim(), sim())) {
            let query = ContextQuery::new(vec![ContextPair::from((p, n))]);

            let score = query.score_by(dummy_similarity);
            // `prop_assert!` reports a failure to proptest instead of panicking,
            // which gives clean shrinking output.
            prop_assert!(score <= 0.0, "similarity: {score}");
            prop_assert!(score > -1.0, "similarity: {score}");
        }
    }

    /// A query with no pairs contributes no loss at all.
    #[test]
    fn empty_query_scores_zero() {
        let query = ContextQuery::<f32>::new(vec![]);
        assert_eq!(query.score_by(dummy_similarity), 0.0);
    }

    /// A pair whose positive clearly beats its negative incurs no loss.
    #[test]
    fn satisfied_pair_has_zero_loss() {
        let query = ContextQuery::new(vec![ContextPair::from((1.0, 0.0))]);
        assert_eq!(query.score_by(dummy_similarity), 0.0);
    }

    /// A pair whose negative beats its positive is penalized.
    #[test]
    fn violated_pair_has_negative_loss() {
        let query = ContextQuery::new(vec![ContextPair::from((0.0, 1.0))]);
        let score = query.score_by(dummy_similarity);
        assert!(score < 0.0, "similarity: {score}");
        assert!(score > -1.0, "similarity: {score}");
    }
}