1use qdrant_client::Qdrant;
2use qdrant_client::qdrant::{
3 Condition, CreateCollectionBuilder, Distance, Filter, PointStruct, SearchPointsBuilder,
4 UpsertPointsBuilder, VectorParamsBuilder,
5};
6use std::collections::HashMap;
7
8use super::error::VectorDbError;
9use super::model::{SearchResult, VectorPoint};
10use crate::vectordb::WriteConsistency;
11
12#[derive(Clone)]
13pub struct QdrantClient {
15 client: Qdrant,
16 url: String,
17}
18
19impl QdrantClient {
20 pub async fn new(url: &str) -> Result<Self, VectorDbError> {
22 let client =
23 Qdrant::from_url(url)
24 .build()
25 .map_err(|e| VectorDbError::ConnectionFailed {
26 url: url.to_string(),
27 message: e.to_string(),
28 })?;
29
30 Ok(Self {
31 client,
32 url: url.to_string(),
33 })
34 }
35
36 pub fn client(&self) -> &Qdrant {
38 &self.client
39 }
40
41 pub fn url(&self) -> &str {
43 &self.url
44 }
45
46 pub async fn health_check(&self) -> Result<(), VectorDbError> {
48 self.client
49 .health_check()
50 .await
51 .map_err(|e| VectorDbError::ConnectionFailed {
52 url: self.url.clone(),
53 message: e.to_string(),
54 })?;
55 Ok(())
56 }
57
58 pub async fn create_collection(
60 &self,
61 name: &str,
62 vector_size: u64,
63 ) -> Result<(), VectorDbError> {
64 let vectors_config = VectorParamsBuilder::new(vector_size, Distance::Cosine);
65
66 self.client
67 .create_collection(
68 CreateCollectionBuilder::new(name)
69 .vectors_config(vectors_config)
70 .on_disk_payload(true),
71 )
72 .await
73 .map_err(|e| VectorDbError::CreateCollectionFailed {
74 collection: name.to_string(),
75 message: e.to_string(),
76 })?;
77
78 Ok(())
79 }
80
81 pub async fn ensure_collection(
83 &self,
84 name: &str,
85 vector_size: u64,
86 ) -> Result<(), VectorDbError> {
87 let exists = self.client.collection_exists(name).await.map_err(|e| {
88 VectorDbError::CreateCollectionFailed {
89 collection: name.to_string(),
90 message: e.to_string(),
91 }
92 })?;
93
94 if !exists {
95 self.create_collection(name, vector_size).await?;
96 }
97
98 Ok(())
99 }
100
101 pub async fn collection_exists(&self, name: &str) -> Result<bool, VectorDbError> {
103 self.client.collection_exists(name).await.map_err(|e| {
104 VectorDbError::CreateCollectionFailed {
105 collection: name.to_string(),
106 message: e.to_string(),
107 }
108 })
109 }
110
111 pub async fn upsert_points(
113 &self,
114 collection: &str,
115 points: Vec<VectorPoint>,
116 consistency: WriteConsistency,
117 ) -> Result<(), VectorDbError> {
118 if points.is_empty() {
119 return Ok(());
120 }
121
122 let qdrant_points: Vec<PointStruct> = points
123 .into_iter()
124 .map(|p| {
125 let mut payload: HashMap<String, qdrant_client::qdrant::Value> = HashMap::new();
126 payload.insert("tenant_id".to_string(), (p.tenant_id as i64).into());
127 payload.insert("context_hash".to_string(), (p.context_hash as i64).into());
128 payload.insert("timestamp".to_string(), p.timestamp.into());
129 if let Some(key) = p.storage_key {
130 payload.insert("storage_key".to_string(), key.into());
131 }
132
133 PointStruct::new(p.id, p.vector, payload)
134 })
135 .collect();
136
137 self.client
138 .upsert_points(
139 UpsertPointsBuilder::new(collection, qdrant_points).wait(consistency.into()),
140 )
141 .await
142 .map_err(|e| VectorDbError::UpsertFailed {
143 collection: collection.to_string(),
144 message: e.to_string(),
145 })?;
146
147 Ok(())
148 }
149
150 pub async fn search(
152 &self,
153 collection: &str,
154 query: Vec<f32>,
155 limit: u64,
156 tenant_filter: Option<u64>,
157 ) -> Result<Vec<SearchResult>, VectorDbError> {
158 let mut search_builder =
159 SearchPointsBuilder::new(collection, query, limit).with_payload(true);
160
161 if let Some(tenant_id) = tenant_filter {
162 let filter = Filter::must([Condition::matches("tenant_id", tenant_id as i64)]);
163 search_builder = search_builder.filter(filter);
164 }
165
166 let search_result = self
167 .client
168 .search_points(search_builder)
169 .await
170 .map_err(|e| VectorDbError::SearchFailed {
171 collection: collection.to_string(),
172 message: e.to_string(),
173 })?;
174
175 let results = search_result
176 .result
177 .into_iter()
178 .filter_map(SearchResult::from_scored_point)
179 .collect();
180
181 Ok(results)
182 }
183
184 pub async fn delete_points(
186 &self,
187 collection: &str,
188 ids: Vec<u64>,
189 ) -> Result<(), VectorDbError> {
190 if ids.is_empty() {
191 return Ok(());
192 }
193
194 use qdrant_client::qdrant::{DeletePointsBuilder, PointsIdsList};
195
196 let points_selector = PointsIdsList {
197 ids: ids.into_iter().map(|id| id.into()).collect(),
198 };
199
200 self.client
201 .delete_points(
202 DeletePointsBuilder::new(collection)
203 .points(points_selector)
204 .wait(true),
205 )
206 .await
207 .map_err(|e| VectorDbError::DeleteFailed {
208 collection: collection.to_string(),
209 message: e.to_string(),
210 })?;
211
212 Ok(())
213 }
214}
215
/// Vector-database operations used by this crate, expressed as a trait.
///
/// Presumably exists so callers can depend on the trait rather than on
/// `QdrantClient` directly (e.g. to substitute a test double); confirm with callers.
/// All returned futures are `Send`, and implementors must be `Send + Sync`.
pub trait VectorDbClient: Send + Sync {
    /// Ensures collection `name` exists, creating it with `vector_size`-dimensional
    /// vectors if it is missing.
    fn ensure_collection(
        &self,
        name: &str,
        vector_size: u64,
    ) -> impl std::future::Future<Output = Result<(), VectorDbError>> + Send;

    /// Inserts or replaces `points` in `collection`, using `consistency` to decide
    /// how the write is acknowledged.
    fn upsert_points(
        &self,
        collection: &str,
        points: Vec<VectorPoint>,
        consistency: WriteConsistency,
    ) -> impl std::future::Future<Output = Result<(), VectorDbError>> + Send;

    /// Returns up to `limit` results nearest to `query` in `collection`, optionally
    /// restricted to a single tenant via `tenant_filter`.
    fn search(
        &self,
        collection: &str,
        query: Vec<f32>,
        limit: u64,
        tenant_filter: Option<u64>,
    ) -> impl std::future::Future<Output = Result<Vec<SearchResult>, VectorDbError>> + Send;

    /// Deletes the points identified by `ids` from `collection`.
    fn delete_points(
        &self,
        collection: &str,
        ids: Vec<u64>,
    ) -> impl std::future::Future<Output = Result<(), VectorDbError>> + Send;
}
249
250impl VectorDbClient for QdrantClient {
251 async fn ensure_collection(&self, name: &str, vector_size: u64) -> Result<(), VectorDbError> {
252 self.ensure_collection(name, vector_size).await
253 }
254
255 async fn upsert_points(
256 &self,
257 collection: &str,
258 points: Vec<VectorPoint>,
259 consistency: WriteConsistency,
260 ) -> Result<(), VectorDbError> {
261 self.upsert_points(collection, points, consistency).await
262 }
263
264 async fn search(
265 &self,
266 collection: &str,
267 query: Vec<f32>,
268 limit: u64,
269 tenant_filter: Option<u64>,
270 ) -> Result<Vec<SearchResult>, VectorDbError> {
271 self.search(collection, query, limit, tenant_filter).await
272 }
273
274 async fn delete_points(&self, collection: &str, ids: Vec<u64>) -> Result<(), VectorDbError> {
275 self.delete_points(collection, ids).await
276 }
277}