Skip to main content

claw_vector/search/
ann.rs

1// search/ann.rs — approximate nearest-neighbour search orchestration.
2use std::{cmp::Ordering, collections::HashMap, sync::Arc, time::Instant};
3
4use futures::future::join_all;
5use tracing::instrument;
6
7use crate::{
8    collections::CollectionManager,
9    embeddings::EmbeddingClient,
10    error::{VectorError, VectorResult},
11    search::{
12        filters::apply_filter,
13        rerank::{apply_reranker_config, reranker_needs_vectors},
14    },
15    types::{DistanceMetric, SearchMetrics, SearchQuery, SearchResponse, SearchResult},
16};
17
18/// Core ANN search service.
19pub struct AnnSearcher {
20    /// Collection and persistence coordinator used during search.
21    pub collection_manager: Arc<CollectionManager>,
22}
23
24impl AnnSearcher {
25    /// Create a new ANN searcher.
26    pub fn new(collection_manager: Arc<CollectionManager>) -> Self {
27        Self { collection_manager }
28    }
29
30    /// Execute a nearest-neighbour search and return filtered, ranked results with metrics.
31    #[instrument(skip(self, query))]
32    pub async fn search(&self, query: SearchQuery) -> VectorResult<SearchResponse> {
33        let workspace_id = self.collection_manager.config.default_workspace_id.clone();
34        self.search_in_workspace(&workspace_id, query).await
35    }
36
37    /// Execute a nearest-neighbour search scoped to a workspace.
38    #[instrument(skip(self, query))]
39    pub async fn search_in_workspace(
40        &self,
41        workspace_id: &str,
42        query: SearchQuery,
43    ) -> VectorResult<SearchResponse> {
44        query.validate()?;
45
46        let started = Instant::now();
47        let collection = self
48            .collection_manager
49            .get_collection(workspace_id, &query.collection)
50            .await?;
51        if query.vector.len() != collection.dimensions {
52            return Err(VectorError::DimensionMismatch {
53                expected: collection.dimensions,
54                got: query.vector.len(),
55            });
56        }
57
58        let candidate_limit = query.top_k.saturating_mul(2).max(query.top_k);
59        let ef_search = query
60            .ef_search
61            .unwrap_or(self.collection_manager.config.ef_search);
62
63        let raw_candidates = {
64            let indexes = self.collection_manager.indexes.read().await;
65            let key = format!("{workspace_id}::{}", query.collection);
66            let index = indexes.get(&key).ok_or_else(|| VectorError::NotFound {
67                entity: "collection".into(),
68                id: format!("{workspace_id}/{}", query.collection),
69            })?;
70            index.search(&query.vector, candidate_limit, ef_search)?
71        };
72
73        let candidate_ids = raw_candidates
74            .iter()
75            .map(|(internal_id, _)| *internal_id)
76            .collect::<Vec<_>>();
77        let records = self
78            .collection_manager
79            .store
80            .bulk_internal_to_uuid(workspace_id, &query.collection, &candidate_ids)
81            .await?;
82        let mut records_by_id: HashMap<usize, crate::types::VectorRecord> =
83            records.into_iter().collect();
84
85        let needs_vectors =
86            query.include_vectors || reranker_needs_vectors(query.reranker.as_ref());
87        let mut results = Vec::new();
88        for (internal_id, distance) in raw_candidates {
89            let record = match records_by_id.remove(&internal_id) {
90                Some(record) => record,
91                None => continue,
92            };
93
94            if let Some(filter) = &query.filter {
95                if !apply_filter(filter, &record.metadata) {
96                    continue;
97                }
98            }
99
100            let vector = if needs_vectors {
101                Some(
102                    self.collection_manager
103                        .read_vector_by_internal_id(workspace_id, &query.collection, internal_id)
104                        .await?,
105                )
106            } else {
107                None
108            };
109
110            results.push(SearchResult {
111                id: record.id,
112                score: normalize_distance(distance, collection.distance),
113                vector,
114                metadata: if query.include_metadata {
115                    record.metadata.clone()
116                } else {
117                    serde_json::Value::Null
118                },
119                text: record.text.clone(),
120                created_at: record.created_at,
121            });
122        }
123
124        let post_filter_count = results.len();
125        let mut results =
126            apply_reranker_config(&query.vector, results, query.reranker.as_ref()).await?;
127        results.sort_by(|left, right| {
128            right
129                .score
130                .partial_cmp(&left.score)
131                .unwrap_or(Ordering::Equal)
132        });
133        results.truncate(query.top_k);
134
135        if !query.include_vectors {
136            for result in &mut results {
137                result.vector = None;
138            }
139        }
140
141        Ok(SearchResponse {
142            metrics: SearchMetrics {
143                query_vector_dims: query.vector.len(),
144                candidates_evaluated: candidate_ids.len(),
145                post_filter_count,
146                latency_us: started.elapsed().as_micros() as u64,
147            },
148            results,
149        })
150    }
151
152    /// Embed free-form text and execute ANN search.
153    #[instrument(skip(self, embedding_client, text))]
154    pub async fn search_by_text(
155        &self,
156        collection: &str,
157        text: &str,
158        top_k: usize,
159        embedding_client: &EmbeddingClient,
160    ) -> VectorResult<SearchResponse> {
161        let vector = embedding_client.embed_one(text).await?;
162        self.search(SearchQuery {
163            collection: collection.to_string(),
164            vector,
165            top_k,
166            filter: None,
167            include_vectors: false,
168            include_metadata: true,
169            ef_search: None,
170            reranker: None,
171        })
172        .await
173    }
174
175    /// Embed free-form text and execute workspace-scoped ANN search.
176    #[instrument(skip(self, embedding_client, text))]
177    pub async fn search_by_text_in_workspace(
178        &self,
179        workspace_id: &str,
180        collection: &str,
181        text: &str,
182        top_k: usize,
183        embedding_client: &EmbeddingClient,
184    ) -> VectorResult<SearchResponse> {
185        let vector = embedding_client.embed_one(text).await?;
186        self.search_in_workspace(
187            workspace_id,
188            SearchQuery {
189                collection: collection.to_string(),
190                vector,
191                top_k,
192                filter: None,
193                include_vectors: false,
194                include_metadata: true,
195                ef_search: None,
196                reranker: None,
197            },
198        )
199        .await
200    }
201
202    /// Execute multiple ANN queries concurrently.
203    #[instrument(skip(self, queries))]
204    pub async fn batch_search(
205        &self,
206        queries: Vec<SearchQuery>,
207    ) -> VectorResult<Vec<SearchResponse>> {
208        let handles = queries
209            .into_iter()
210            .map(|query| {
211                let searcher = AnnSearcher {
212                    collection_manager: Arc::clone(&self.collection_manager),
213                };
214                tokio::task::spawn(async move { searcher.search(query).await })
215            })
216            .collect::<Vec<_>>();
217
218        let mut responses = Vec::with_capacity(handles.len());
219        for handle in join_all(handles).await {
220            let response = handle.map_err(|err| {
221                VectorError::SearchError(format!("ANN batch task failed: {err}"))
222            })??;
223            responses.push(response);
224        }
225
226        Ok(responses)
227    }
228}
229
230fn normalize_distance(distance: f32, metric: DistanceMetric) -> f32 {
231    match metric {
232        DistanceMetric::Cosine | DistanceMetric::Euclidean => {
233            (1.0 / (1.0 + distance.max(0.0))).clamp(0.0, 1.0)
234        }
235        DistanceMetric::DotProduct => (1.0 / (1.0 + distance.exp())).clamp(0.0, 1.0),
236    }
237}