qdrant_edge/segment/vector_storage/query/
feedback_query.rs1use std::hash::Hash;
2
3use crate::common::types::ScoreType;
4use itertools::Itertools;
5use ordered_float::OrderedFloat;
6use serde::Serialize;
7
8use super::{Query, TransformInto};
9use crate::segment::common::operation_error::OperationResult;
10
11#[derive(Clone, Debug, Serialize, Hash, PartialEq)]
12pub struct FeedbackItem<T> {
13 pub vector: T,
14 pub score: OrderedFloat<ScoreType>,
15}
16
17impl<T> FeedbackItem<T> {
18 pub fn transform<F, U>(self, mut f: F) -> OperationResult<FeedbackItem<U>>
19 where
20 F: FnMut(T) -> OperationResult<U>,
21 {
22 Ok(FeedbackItem {
23 vector: f(self.vector)?,
24 score: self.score,
25 })
26 }
27}
28
29#[derive(Clone, Debug, Serialize, Hash, PartialEq)]
33pub struct NaiveFeedbackQuery<T> {
34 pub target: T,
36
37 pub feedback: Vec<FeedbackItem<T>>,
39
40 pub coefficients: NaiveFeedbackCoefficients,
42}
43
44impl<T: Clone> NaiveFeedbackQuery<T> {
45 pub fn into_query(self) -> FeedbackQuery<T> {
46 FeedbackQuery::new(self.target, self.feedback, self.coefficients)
47 }
48}
49
50impl<T> NaiveFeedbackQuery<T> {
51 pub fn flat_iter(&self) -> impl Iterator<Item = &T> {
52 self.feedback
53 .iter()
54 .map(|item| &item.vector)
55 .chain(std::iter::once(&self.target))
56 }
57}
58
59impl<T, U> TransformInto<NaiveFeedbackQuery<U>, T, U> for NaiveFeedbackQuery<T> {
60 fn transform<F>(self, mut f: F) -> OperationResult<NaiveFeedbackQuery<U>>
61 where
62 F: FnMut(T) -> OperationResult<U>,
63 {
64 let Self {
65 target,
66 feedback,
67 coefficients,
68 } = self;
69 Ok(NaiveFeedbackQuery {
70 target: f(target)?,
71 feedback: feedback
72 .into_iter()
73 .map(|item| item.transform(&mut f))
74 .try_collect()?,
75 coefficients,
76 })
77 }
78}
79
80#[derive(Debug, Clone, PartialEq, Serialize, Hash)]
81pub struct ContextPair<T> {
82 pub positive: T,
84 pub negative: T,
86 pub partial_computation: OrderedFloat<f32>,
88}
89
90impl<T> ContextPair<T> {
91 pub fn transform<F, U>(self, mut f: F) -> OperationResult<ContextPair<U>>
92 where
93 F: FnMut(T) -> OperationResult<U>,
94 {
95 Ok(ContextPair {
96 positive: f(self.positive)?,
97 negative: f(self.negative)?,
98 partial_computation: self.partial_computation,
99 })
100 }
101}
102
103#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Serialize)]
105pub struct NaiveFeedbackCoefficients {
106 pub a: OrderedFloat<f32>,
108 pub b: OrderedFloat<f32>,
110 pub c: OrderedFloat<f32>,
112}
113
114impl NaiveFeedbackCoefficients {
115 fn extract_context_pairs<TVector: Clone>(
119 &self,
120 feedback: Vec<FeedbackItem<TVector>>,
121 margin: f32,
122 ) -> Vec<ContextPair<TVector>> {
123 if feedback.len() < 2 {
124 return Vec::new();
126 }
127
128 let mut feedback_pairs = Vec::new();
129 for permutation in feedback.iter().permutations(2) {
130 let (positive, negative) = (permutation[0], permutation[1]);
131 let confidence = positive.score - negative.score;
132
133 if confidence.0 <= margin {
134 continue;
135 }
136
137 let partial_computation = confidence.powf(self.b.0) * self.c.0;
138 feedback_pairs.push(ContextPair {
139 positive: positive.vector.clone(),
140 negative: negative.vector.clone(),
141 partial_computation: partial_computation.into(),
142 });
143 }
144 feedback_pairs
145 }
146}
147
148#[derive(Debug, Clone, PartialEq, Serialize, Hash)]
150pub struct FeedbackQuery<TVector> {
151 target: TVector,
153
154 context_pairs: Vec<ContextPair<TVector>>,
156
157 coefficients: NaiveFeedbackCoefficients,
159}
160
161impl<TVector: Clone> FeedbackQuery<TVector> {
162 pub fn new(
163 target: TVector,
164 feedback: Vec<FeedbackItem<TVector>>,
165 coefficients: NaiveFeedbackCoefficients,
166 ) -> Self {
167 let context_pairs = coefficients.extract_context_pairs(feedback, 0.0);
168
169 Self {
170 target,
171 context_pairs,
172 coefficients,
173 }
174 }
175}
176
177impl<T, U> TransformInto<FeedbackQuery<U>, T, U> for FeedbackQuery<T> {
178 fn transform<F>(self, mut f: F) -> OperationResult<FeedbackQuery<U>>
179 where
180 F: FnMut(T) -> OperationResult<U>,
181 {
182 let Self {
183 target,
184 context_pairs,
185 coefficients,
186 } = self;
187 Ok(FeedbackQuery {
188 target: f(target)?,
189 context_pairs: context_pairs
190 .into_iter()
191 .map(|pair| pair.transform(&mut f))
192 .try_collect()?,
193 coefficients,
194 })
195 }
196}
197
198impl<T> Query<T> for FeedbackQuery<T> {
199 fn score_by(&self, similarity: impl Fn(&T) -> ScoreType) -> ScoreType {
208 let Self {
209 target,
210 context_pairs,
211 coefficients,
212 } = self;
213
214 let mut score = coefficients.a.0 * similarity(target);
215
216 for pair in context_pairs {
217 let ContextPair {
218 positive,
219 negative,
220 partial_computation,
221 } = pair;
222
223 let delta = similarity(positive) - similarity(negative);
224
225 score += partial_computation.0 * delta;
226 }
227
228 score
229 }
230}