1use std::{cmp::Ordering, collections::HashMap, sync::Arc, time::Instant};
3
4use futures::future::join_all;
5use tracing::instrument;
6
7use crate::{
8 collections::CollectionManager,
9 embeddings::EmbeddingClient,
10 error::{VectorError, VectorResult},
11 search::{
12 filters::apply_filter,
13 rerank::{apply_reranker_config, reranker_needs_vectors},
14 },
15 types::{DistanceMetric, SearchMetrics, SearchQuery, SearchResponse, SearchResult},
16};
17
18pub struct AnnSearcher {
20 pub collection_manager: Arc<CollectionManager>,
22}
23
24impl AnnSearcher {
25 pub fn new(collection_manager: Arc<CollectionManager>) -> Self {
27 Self { collection_manager }
28 }
29
30 #[instrument(skip(self, query))]
32 pub async fn search(&self, query: SearchQuery) -> VectorResult<SearchResponse> {
33 let workspace_id = self.collection_manager.config.default_workspace_id.clone();
34 self.search_in_workspace(&workspace_id, query).await
35 }
36
37 #[instrument(skip(self, query))]
39 pub async fn search_in_workspace(
40 &self,
41 workspace_id: &str,
42 query: SearchQuery,
43 ) -> VectorResult<SearchResponse> {
44 query.validate()?;
45
46 let started = Instant::now();
47 let collection = self
48 .collection_manager
49 .get_collection(workspace_id, &query.collection)
50 .await?;
51 if query.vector.len() != collection.dimensions {
52 return Err(VectorError::DimensionMismatch {
53 expected: collection.dimensions,
54 got: query.vector.len(),
55 });
56 }
57
58 let candidate_limit = query.top_k.saturating_mul(2).max(query.top_k);
59 let ef_search = query
60 .ef_search
61 .unwrap_or(self.collection_manager.config.ef_search);
62
63 let raw_candidates = {
64 let indexes = self.collection_manager.indexes.read().await;
65 let key = format!("{workspace_id}::{}", query.collection);
66 let index = indexes.get(&key).ok_or_else(|| VectorError::NotFound {
67 entity: "collection".into(),
68 id: format!("{workspace_id}/{}", query.collection),
69 })?;
70 index.search(&query.vector, candidate_limit, ef_search)?
71 };
72
73 let candidate_ids = raw_candidates
74 .iter()
75 .map(|(internal_id, _)| *internal_id)
76 .collect::<Vec<_>>();
77 let records = self
78 .collection_manager
79 .store
80 .bulk_internal_to_uuid(workspace_id, &query.collection, &candidate_ids)
81 .await?;
82 let mut records_by_id: HashMap<usize, crate::types::VectorRecord> =
83 records.into_iter().collect();
84
85 let needs_vectors =
86 query.include_vectors || reranker_needs_vectors(query.reranker.as_ref());
87 let mut results = Vec::new();
88 for (internal_id, distance) in raw_candidates {
89 let record = match records_by_id.remove(&internal_id) {
90 Some(record) => record,
91 None => continue,
92 };
93
94 if let Some(filter) = &query.filter {
95 if !apply_filter(filter, &record.metadata) {
96 continue;
97 }
98 }
99
100 let vector = if needs_vectors {
101 Some(
102 self.collection_manager
103 .read_vector_by_internal_id(workspace_id, &query.collection, internal_id)
104 .await?,
105 )
106 } else {
107 None
108 };
109
110 results.push(SearchResult {
111 id: record.id,
112 score: normalize_distance(distance, collection.distance),
113 vector,
114 metadata: if query.include_metadata {
115 record.metadata.clone()
116 } else {
117 serde_json::Value::Null
118 },
119 text: record.text.clone(),
120 created_at: record.created_at,
121 });
122 }
123
124 let post_filter_count = results.len();
125 let mut results =
126 apply_reranker_config(&query.vector, results, query.reranker.as_ref()).await?;
127 results.sort_by(|left, right| {
128 right
129 .score
130 .partial_cmp(&left.score)
131 .unwrap_or(Ordering::Equal)
132 });
133 results.truncate(query.top_k);
134
135 if !query.include_vectors {
136 for result in &mut results {
137 result.vector = None;
138 }
139 }
140
141 Ok(SearchResponse {
142 metrics: SearchMetrics {
143 query_vector_dims: query.vector.len(),
144 candidates_evaluated: candidate_ids.len(),
145 post_filter_count,
146 latency_us: started.elapsed().as_micros() as u64,
147 },
148 results,
149 })
150 }
151
152 #[instrument(skip(self, embedding_client, text))]
154 pub async fn search_by_text(
155 &self,
156 collection: &str,
157 text: &str,
158 top_k: usize,
159 embedding_client: &EmbeddingClient,
160 ) -> VectorResult<SearchResponse> {
161 let vector = embedding_client.embed_one(text).await?;
162 self.search(SearchQuery {
163 collection: collection.to_string(),
164 vector,
165 top_k,
166 filter: None,
167 include_vectors: false,
168 include_metadata: true,
169 ef_search: None,
170 reranker: None,
171 })
172 .await
173 }
174
175 #[instrument(skip(self, embedding_client, text))]
177 pub async fn search_by_text_in_workspace(
178 &self,
179 workspace_id: &str,
180 collection: &str,
181 text: &str,
182 top_k: usize,
183 embedding_client: &EmbeddingClient,
184 ) -> VectorResult<SearchResponse> {
185 let vector = embedding_client.embed_one(text).await?;
186 self.search_in_workspace(
187 workspace_id,
188 SearchQuery {
189 collection: collection.to_string(),
190 vector,
191 top_k,
192 filter: None,
193 include_vectors: false,
194 include_metadata: true,
195 ef_search: None,
196 reranker: None,
197 },
198 )
199 .await
200 }
201
202 #[instrument(skip(self, queries))]
204 pub async fn batch_search(
205 &self,
206 queries: Vec<SearchQuery>,
207 ) -> VectorResult<Vec<SearchResponse>> {
208 let handles = queries
209 .into_iter()
210 .map(|query| {
211 let searcher = AnnSearcher {
212 collection_manager: Arc::clone(&self.collection_manager),
213 };
214 tokio::task::spawn(async move { searcher.search(query).await })
215 })
216 .collect::<Vec<_>>();
217
218 let mut responses = Vec::with_capacity(handles.len());
219 for handle in join_all(handles).await {
220 let response = handle.map_err(|err| {
221 VectorError::SearchError(format!("ANN batch task failed: {err}"))
222 })??;
223 responses.push(response);
224 }
225
226 Ok(responses)
227 }
228}
229
230fn normalize_distance(distance: f32, metric: DistanceMetric) -> f32 {
231 match metric {
232 DistanceMetric::Cosine | DistanceMetric::Euclidean => {
233 (1.0 / (1.0 + distance.max(0.0))).clamp(0.0, 1.0)
234 }
235 DistanceMetric::DotProduct => (1.0 / (1.0 + distance.exp())).clamp(0.0, 1.0),
236 }
237}