qdrant/
lib.rs

1#[cfg(feature = "logging")]
2#[macro_use]
3extern crate log;
4
5use anyhow::{anyhow, bail, Error};
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use serde_json::{Map, Value};
9use std::fmt::Display;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12#[serde(untagged)]
13pub enum PointId {
14    Uuid(String),
15    Num(u64),
16}
17impl From<u64> for PointId {
18    fn from(num: u64) -> Self {
19        PointId::Num(num)
20    }
21}
22impl From<String> for PointId {
23    fn from(uuid: String) -> Self {
24        PointId::Uuid(uuid)
25    }
26}
27impl Display for PointId {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        match self {
30            PointId::Uuid(uuid) => write!(f, "{}", uuid),
31            PointId::Num(num) => write!(f, "{}", num),
32        }
33    }
34}
35
36/// The point struct.
37/// A point is a record consisting of a vector and an optional payload.
38#[derive(Debug, Serialize, Deserialize)]
39#[serde(rename_all = "camelCase")]
40pub struct Point {
41    /// Id of the point
42    pub id: PointId,
43
44    /// Vectors
45    pub vector: Vec<f32>,
46
47    /// Additional information along with vectors
48    pub payload: Option<Map<String, Value>>,
49}
50
51/// The point struct with the score returned by searching
52#[derive(Debug, Serialize, Deserialize)]
53#[serde(rename_all = "camelCase")]
54pub struct ScoredPoint {
55    /// Id of the point
56    pub id: PointId,
57
58    /// Vectors
59    pub vector: Option<Vec<f32>>,
60
61    /// Additional information along with vectors
62    pub payload: Option<Map<String, Value>>,
63
64    /// Points vector distance to the query vector
65    pub score: f32,
66}
67
/// A thin HTTP client for the Qdrant REST API.
///
/// Every request URL is built from `url_base`. When an API key has been set
/// with `set_api_key`, it is sent on each request as the `api-key` header.
#[derive(Clone)]
pub struct Qdrant {
    /// Base URL of the Qdrant server, e.g. `http://localhost:6333`.
    pub url_base: String,
    /// Optional API key. It is private, so it can only be set through
    /// `set_api_key`.
    api_key: Option<String>,
}

impl std::fmt::Debug for Qdrant {
    /// Shows the base URL but redacts the API key, so debug-printing a client
    /// never leaks the secret.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Qdrant")
            .field("url_base", &self.url_base)
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
72
73impl Qdrant {
74    pub fn new_with_url(url_base_: String) -> Qdrant {
75        Qdrant {
76            url_base: url_base_,
77            api_key: None,
78        }
79    }
80
81    pub fn new() -> Qdrant {
82        Qdrant::new_with_url("http://localhost:6333".to_string())
83    }
84
85    pub fn set_api_key(&mut self, api_key: impl Into<String>) {
86        self.api_key = Some(api_key.into());
87    }
88}
89
90impl Default for Qdrant {
91    fn default() -> Self {
92        Self::new()
93    }
94}
95
96/// Shortcut functions
97impl Qdrant {
98    /// Shortcut functions
99    pub async fn collection_info(&self, collection_name: &str) -> u64 {
100        #[cfg(feature = "logging")]
101        info!(target: "stdout", "get collection info: '{}'", collection_name);
102
103        let v = self.collection_info_api(collection_name).await.unwrap();
104        v.get("result")
105            .unwrap()
106            .get("points_count")
107            .unwrap()
108            .as_u64()
109            .unwrap()
110    }
111
112    pub async fn create_collection(&self, collection_name: &str, size: u32) -> Result<(), Error> {
113        #[cfg(feature = "logging")]
114        info!(target: "stdout", "create collection '{}'", collection_name);
115
116        match self.collection_exists(collection_name).await {
117            Ok(false) => (),
118            Ok(true) => {
119                let err_msg = format!("Collection '{}' already exists", collection_name);
120
121                #[cfg(feature = "logging")]
122                error!(target: "stdout", "{}", &err_msg);
123
124                bail!(err_msg);
125            }
126            Err(e) => {
127                #[cfg(feature = "logging")]
128                error!(target: "stdout", "{}", e);
129
130                bail!("{}", e);
131            }
132        }
133
134        let params = json!({
135            "vectors": {
136                "size": size,
137                "distance": "Cosine",
138                "on_disk": true,
139            }
140        });
141        if !self.create_collection_api(collection_name, &params).await? {
142            bail!("Failed to create collection '{}'", collection_name);
143        }
144        Ok(())
145    }
146
147    pub async fn list_collections(&self) -> Result<Vec<String>, Error> {
148        #[cfg(feature = "logging")]
149        info!(target: "stdout", "list collections");
150
151        self.list_collections_api().await
152    }
153
154    pub async fn collection_exists(&self, collection_name: &str) -> Result<bool, Error> {
155        #[cfg(feature = "logging")]
156        info!(target: "stdout", "check collection existence: {}", collection_name);
157
158        let collection_names = self.list_collections().await?;
159
160        Ok(collection_names.contains(&collection_name.to_string()))
161    }
162
163    pub async fn delete_collection(&self, collection_name: &str) -> Result<(), Error> {
164        #[cfg(feature = "logging")]
165        info!(target: "stdout", "delete collection '{}'", collection_name);
166
167        match self.collection_exists(collection_name).await {
168            Ok(true) => (),
169            Ok(false) => {
170                let err_msg = format!("Not found collection '{}'", collection_name);
171
172                #[cfg(feature = "logging")]
173                error!(target: "stdout", "{}", &err_msg);
174
175                bail!(err_msg);
176            }
177            Err(e) => {
178                #[cfg(feature = "logging")]
179                error!(target: "stdout", "{}", e);
180
181                bail!("{}", e);
182            }
183        }
184
185        if !self.delete_collection_api(collection_name).await? {
186            bail!("Failed to delete collection '{}'", collection_name);
187        }
188        Ok(())
189    }
190
191    pub async fn upsert_points(
192        &self,
193        collection_name: &str,
194        points: Vec<Point>,
195    ) -> Result<(), Error> {
196        #[cfg(feature = "logging")]
197        info!(target: "stdout", "upsert {} points to collection '{}'", points.len(), collection_name);
198
199        let params = json!({
200            "points": points,
201        });
202        self.upsert_points_api(collection_name, &params).await
203    }
204
205    pub async fn search_points(
206        &self,
207        collection_name: &str,
208        vector: Vec<f32>,
209        limit: u64,
210        score_threshold: Option<f32>,
211    ) -> Result<Vec<ScoredPoint>, Error> {
212        #[cfg(feature = "logging")]
213        info!(target: "stdout", "search points in collection '{}'", collection_name);
214
215        let score_threshold = score_threshold.unwrap_or(0.0);
216
217        let params = json!({
218            "vector": vector,
219            "limit": limit,
220            "with_payload": true,
221            "with_vector": true,
222            "score_threshold": score_threshold,
223        });
224
225        match self.search_points_api(collection_name, &params).await {
226            Ok(v) => {
227                match v.get("result") {
228                    Some(v) => match v.as_array() {
229                        Some(rs) => {
230                            let mut sps: Vec<ScoredPoint> = Vec::<ScoredPoint>::new();
231                            for r in rs {
232                                let sp: ScoredPoint = serde_json::from_value(r.clone()).unwrap();
233                                sps.push(sp);
234                            }
235                            Ok(sps)
236                        }
237                        None => {
238                            bail!("[qdrant] The value corresponding to the 'result' key is not an array.")
239                        }
240                    },
241                    None => {
242                        let warn_msg = "[qdrant] The given key 'result' does not exist.";
243
244                        #[cfg(feature = "logging")]
245                        warn!(target: "stdout", "{}", warn_msg);
246
247                        Ok(vec![])
248                    }
249                }
250            }
251            Err(e) => {
252                let warn_msg = format!("[qdrant] Failed to search points: {}", e);
253
254                #[cfg(feature = "logging")]
255                warn!(target: "stdout", "{}", warn_msg);
256
257                Ok(vec![])
258            }
259        }
260    }
261
262    pub async fn get_points(&self, collection_name: &str, ids: &[PointId]) -> Vec<Point> {
263        #[cfg(feature = "logging")]
264        info!(target: "stdout", "get points from collection '{}'", collection_name);
265
266        let params = json!({
267            "ids": ids,
268            "with_payload": true,
269            "with_vector": true,
270        });
271
272        let v = self.get_points_api(collection_name, &params).await.unwrap();
273        let rs: &Vec<Value> = v.get("result").unwrap().as_array().unwrap();
274        let mut ps: Vec<Point> = Vec::<Point>::new();
275        for r in rs {
276            let p: Point = serde_json::from_value(r.clone()).unwrap();
277            ps.push(p);
278        }
279        ps
280    }
281
282    pub async fn get_point(&self, collection_name: &str, id: &PointId) -> Point {
283        #[cfg(feature = "logging")]
284        info!(target: "stdout", "get point from collection '{}' with id {}", collection_name, id);
285
286        let v = self.get_point_api(collection_name, id).await.unwrap();
287        let r = v.get("result").unwrap();
288        serde_json::from_value(r.clone()).unwrap()
289    }
290
291    pub async fn delete_points(&self, collection_name: &str, ids: &[PointId]) -> Result<(), Error> {
292        #[cfg(feature = "logging")]
293        info!(target: "stdout", "delete points from collection '{}'", collection_name);
294
295        let params = json!({
296            "points": ids,
297        });
298        self.delete_points_api(collection_name, &params).await
299    }
300
301    /// REST API functions
302    pub async fn collection_info_api(&self, collection_name: &str) -> Result<Value, Error> {
303        let url = format!("{}/collections/{}", self.url_base, collection_name,);
304
305        let client = reqwest::Client::new();
306
307        let ci = match &self.api_key {
308            Some(api_key) => {
309                client
310                    .get(&url)
311                    .header("api-key", api_key)
312                    .header("Content-Type", "application/json")
313                    .send()
314                    .await?
315                    .json()
316                    .await?
317            }
318            None => {
319                client
320                    .get(&url)
321                    .header("Content-Type", "application/json")
322                    .send()
323                    .await?
324                    .json()
325                    .await?
326            }
327        };
328
329        Ok(ci)
330    }
331
332    pub async fn create_collection_api(
333        &self,
334        collection_name: &str,
335        params: &Value,
336    ) -> Result<bool, Error> {
337        let url = format!("{}/collections/{}", self.url_base, collection_name,);
338
339        let body = serde_json::to_vec(params).unwrap_or_default();
340        let client = reqwest::Client::new();
341        let res = match &self.api_key {
342            Some(api_key) => {
343                client
344                    .put(&url)
345                    .header("api-key", api_key)
346                    .header("Content-Type", "application/json")
347                    .body(body)
348                    .send()
349                    .await?
350            }
351            None => {
352                client
353                    .put(&url)
354                    .header("Content-Type", "application/json")
355                    .body(body)
356                    .send()
357                    .await?
358            }
359        };
360
361        match res.status().is_success() {
362            true => {
363                // get response body as json
364                let json = res.json::<Value>().await?;
365                let sucess = json.get("result").unwrap().as_bool().unwrap();
366                Ok(sucess)
367            }
368            false => Err(anyhow!(
369                "[qdrant] Failed to create collection: {}",
370                collection_name
371            )),
372        }
373    }
374
375    pub async fn list_collections_api(&self) -> Result<Vec<String>, Error> {
376        let url = format!("{}/collections", self.url_base);
377        let client = reqwest::Client::new();
378        let result = match &self.api_key {
379            Some(api_key) => {
380                client
381                    .get(&url)
382                    .header("api-key", api_key)
383                    .header("Content-Type", "application/json")
384                    .send()
385                    .await
386            }
387            None => {
388                client
389                    .get(&url)
390                    .header("Content-Type", "application/json")
391                    .send()
392                    .await
393            }
394        };
395
396        let response = match result {
397            Ok(response) => response,
398            Err(e) => {
399                #[cfg(feature = "logging")]
400                error!(target: "stdout", "{}", e);
401
402                bail!("{}", e);
403            }
404        };
405
406        match response.status().is_success() {
407            true => match response.json::<Value>().await {
408                Ok(json) => match json.get("result") {
409                    Some(result) => match result.get("collections") {
410                        Some(collections) => match collections.as_array() {
411                            Some(collections) => {
412                                let mut collection_names: Vec<String> = Vec::<String>::new();
413
414                                for collection in collections {
415                                    let name = collection.get("name").unwrap().as_str().unwrap();
416                                    collection_names.push(name.to_string());
417                                }
418
419                                Ok(collection_names)
420                            },
421                            None => bail!("[qdrant] The value corresponding to the 'collections' key is not an array."),
422                        },
423                        None => bail!("[qdrant] The given key 'collections' does not exist."),
424                    },
425                    None => bail!("[qdrant] The given key 'result' does not exist."),
426                },
427                Err(e) => {
428                    #[cfg(feature = "logging")]
429                    error!(target: "stdout", "{}", e);
430
431                    bail!("{}", e);
432                }
433            }
434            false => bail!("[qdrant] Failed to list collections"),
435        }
436    }
437
438    pub async fn collection_exists_api(&self, collection_name: &str) -> Result<bool, Error> {
439        #[cfg(feature = "logging")]
440        info!(target: "stdout", "check collection existence: {}", collection_name);
441
442        let url = format!("{}/collections/{}/exists", self.url_base, collection_name,);
443        let client = reqwest::Client::new();
444
445        #[cfg(feature = "logging")]
446        info!(target: "stdout", "check collection existence: {}", url);
447
448        let result = match &self.api_key {
449            Some(api_key) => {
450                client
451                    .get(&url)
452                    .header("api-key", api_key)
453                    .header("Content-Type", "application/json")
454                    .send()
455                    .await
456            }
457            None => {
458                client
459                    .get(&url)
460                    .header("Content-Type", "application/json")
461                    .send()
462                    .await
463            }
464        };
465
466        #[cfg(feature = "logging")]
467        info!(target: "stdout", "result: {:?}", result);
468
469        let response = match result {
470            Ok(response) => response,
471            Err(e) => {
472                #[cfg(feature = "logging")]
473                error!(target: "stdout", "{}", e);
474
475                bail!("{}", e);
476            }
477        };
478
479        let json = match response.json::<Value>().await {
480            Ok(json) => json,
481            Err(e) => {
482                #[cfg(feature = "logging")]
483                error!(target: "stdout", "{}", e);
484
485                bail!("{}", e);
486            }
487        };
488
489        #[cfg(feature = "logging")]
490        info!(target: "stdout", "json: {:?}", json);
491
492        match json.get("result") {
493            Some(result) => {
494                let exists = result.get("exists").unwrap().as_bool().unwrap();
495                Ok(exists)
496            }
497            None => Err(anyhow!("[qdrant] Failed to check collection existence")),
498        }
499
500        // match res.status().is_success() {
501        //     true => {
502        //         // get response body as json
503        //         let json = res.json::<Value>().await?;
504        //         let exists = json
505        //             .get("result")
506        //             .unwrap()
507        //             .get("exists")
508        //             .unwrap()
509        //             .as_bool()
510        //             .unwrap();
511        //         Ok(exists)
512        //     }
513        //     false => Err(anyhow!("[qdrant] Failed to check collection existence")),
514        // }
515    }
516
517    pub async fn delete_collection_api(&self, collection_name: &str) -> Result<bool, Error> {
518        let url = format!("{}/collections/{}", self.url_base, collection_name,);
519
520        let client = reqwest::Client::new();
521
522        let res = match &self.api_key {
523            Some(api_key) => {
524                client
525                    .delete(&url)
526                    .header("api-key", api_key)
527                    .header("Content-Type", "application/json")
528                    .send()
529                    .await?
530            }
531            None => {
532                client
533                    .delete(&url)
534                    .header("Content-Type", "application/json")
535                    .send()
536                    .await?
537            }
538        };
539
540        match res.status().is_success() {
541            true => {
542                // get response body as json
543                let json = res.json::<Value>().await?;
544                let sucess = json.get("result").unwrap().as_bool().unwrap();
545                Ok(sucess)
546            }
547            false => Err(anyhow!(
548                "[qdrant] Failed to delete collection: {}",
549                collection_name
550            )),
551        }
552    }
553
554    pub async fn upsert_points_api(
555        &self,
556        collection_name: &str,
557        params: &Value,
558    ) -> Result<(), Error> {
559        let url = format!(
560            "{}/collections/{}/points?wait=true",
561            self.url_base, collection_name,
562        );
563
564        let body = serde_json::to_vec(params).unwrap_or_default();
565        let client = reqwest::Client::new();
566        let res = match &self.api_key {
567            Some(api_key) => {
568                client
569                    .put(&url)
570                    .header("api-key", api_key)
571                    .header("Content-Type", "application/json")
572                    .body(body)
573                    .send()
574                    .await?
575            }
576            None => {
577                client
578                    .put(&url)
579                    .header("Content-Type", "application/json")
580                    .body(body)
581                    .send()
582                    .await?
583            }
584        };
585
586        if res.status().is_success() {
587            let v = res.json::<Value>().await?;
588            let status = v.get("status").unwrap().as_str().unwrap();
589            if status == "ok" {
590                Ok(())
591            } else {
592                Err(anyhow!(
593                    "[qdrant] Failed to upsert points. Status = {}",
594                    status
595                ))
596            }
597        } else {
598            Err(anyhow!(
599                "[qdrant] Failed to upsert points: {}",
600                res.status().as_str()
601            ))
602        }
603    }
604
605    pub async fn search_points_api(
606        &self,
607        collection_name: &str,
608        params: &Value,
609    ) -> Result<Value, Error> {
610        let url = format!(
611            "{}/collections/{}/points/search",
612            self.url_base, collection_name,
613        );
614
615        let body = serde_json::to_vec(params).unwrap_or_default();
616        let client = reqwest::Client::new();
617        let response = match &self.api_key {
618            Some(api_key) => {
619                client
620                    .post(&url)
621                    .header("api-key", api_key)
622                    .header("Content-Type", "application/json")
623                    .body(body)
624                    .send()
625                    .await?
626            }
627            None => {
628                client
629                    .post(&url)
630                    .header("Content-Type", "application/json")
631                    .body(body)
632                    .send()
633                    .await?
634            }
635        };
636
637        let status_code = response.status();
638        match status_code.is_success() {
639            true => {
640                let json = response.json().await?;
641                Ok(json)
642            }
643            false => {
644                let status = status_code.as_str();
645                Err(anyhow!("[qdrant] Failed to search points: {}", status))
646            }
647        }
648    }
649
650    pub async fn get_points_api(
651        &self,
652        collection_name: &str,
653        params: &Value,
654    ) -> Result<Value, Error> {
655        let url = format!("{}/collections/{}/points", self.url_base, collection_name,);
656
657        let body = serde_json::to_vec(params).unwrap_or_default();
658        let client = reqwest::Client::new();
659
660        let json = match &self.api_key {
661            Some(api_key) => {
662                client
663                    .post(&url)
664                    .header("api-key", api_key)
665                    .header("Content-Type", "application/json")
666                    .body(body)
667                    .send()
668                    .await?
669                    .json()
670                    .await?
671            }
672            None => {
673                client
674                    .post(&url)
675                    .header("Content-Type", "application/json")
676                    .body(body)
677                    .send()
678                    .await?
679                    .json()
680                    .await?
681            }
682        };
683
684        Ok(json)
685    }
686
687    pub async fn get_point_api(&self, collection_name: &str, id: &PointId) -> Result<Value, Error> {
688        let url = format!(
689            "{}/collections/{}/points/{}",
690            self.url_base, collection_name, id,
691        );
692
693        let client = reqwest::Client::new();
694
695        let json = match &self.api_key {
696            Some(api_key) => {
697                client
698                    .get(&url)
699                    .header("api-key", api_key)
700                    .header("Content-Type", "application/json")
701                    .send()
702                    .await?
703                    .json()
704                    .await?
705            }
706            None => {
707                client
708                    .get(&url)
709                    .header("Content-Type", "application/json")
710                    .send()
711                    .await?
712                    .json()
713                    .await?
714            }
715        };
716
717        Ok(json)
718    }
719
720    pub async fn delete_points_api(
721        &self,
722        collection_name: &str,
723        params: &Value,
724    ) -> Result<(), Error> {
725        let url = format!(
726            "{}/collections/{}/points/delete?wait=true",
727            self.url_base, collection_name,
728        );
729
730        let body = serde_json::to_vec(params).unwrap_or_default();
731        let client = reqwest::Client::new();
732
733        let res = match &self.api_key {
734            Some(api_key) => {
735                client
736                    .post(&url)
737                    .header("api-key", api_key)
738                    .header("Content-Type", "application/json")
739                    .body(body)
740                    .send()
741                    .await?
742            }
743            None => {
744                client
745                    .post(&url)
746                    .header("Content-Type", "application/json")
747                    .body(body)
748                    .send()
749                    .await?
750            }
751        };
752
753        if res.status().is_success() {
754            Ok(())
755        } else {
756            Err(anyhow!(
757                "[qdrant] Failed to delete points: {}",
758                res.status().as_str()
759            ))
760        }
761    }
762}