arc_vector_rust/
lib.rs

1//! The ArcVector Vector Database client
2//!
//! This library uses gRPC to connect to the ArcVector server and gives you
//! access to most, if not all, of its features. If you find a missing feature,
//! please open an [issue](https://github.com/arc_vector/rust-client/issues/new).
6//!
7//! If you use this library, you'll likely want to import the usual types and
8//! functions:
9//! ```
//! #[allow(unused_imports)]
11//! use arc_vector_rust::prelude::*;
12//! ```
13//!
//! To work with an ArcVector database, you'll first need to connect by creating
//! an [`ArcVectorClient`](crate::client::ArcVectorClient):
16//! ```
17//!# use arc_vector_rust::prelude::*;
18//!# fn establish_connection(url: &str) -> anyhow::Result<ArcVectorClient> {
19//! let mut config = ArcVectorClientConfig::from_url(url);
20//! config.api_key = std::env::var("ARC_VECTOR_API_KEY").ok();
21//! ArcVectorClient::new(Some(config))
22//!# }
23//! ```
24//!
25//! ArcVector works with *Collections* of *Points*. To add vector data, you first
26//! create a collection:
27//!
28//! ```
29//!# use arc_vector_rust::prelude::*;
30//! use arc_vector_rust::arc_vector::{VectorParams, VectorsConfig};
31//! use arc_vector_rust::arc_vector::vectors_config::Config;
32//!# async fn create_collection(arc_vector_client: &ArcVectorClient)
33//!# -> Result<(), Box<dyn std::error::Error>> {
34//! let response = arc_vector_client
35//!     .create_collection(&CreateCollection {
36//!         collection_name: "my_collection".into(),
37//!         vectors_config: Some(VectorsConfig {
38//!             config: Some(Config::Params(VectorParams {
39//!                 size: 512,
40//!                 distance: Distance::Cosine as i32,
41//!                 ..Default::default()
42//!             })),
43//!         }),
44//!         ..Default::default()
45//!     })
46//!     .await?;
47//!# Ok(())
48//!# }
49//! ```
//! The most interesting parts are the `collection_name`, the
//! `vectors_config.size` (the length of the vectors to store) and the
//! `distance` (the [`Distance`](crate::arc_vector::Distance) measure used to
//! gauge similarity in the nearest-neighbor search).
54//!
//! Now that we have a collection, we can insert (or rather upsert) points.
//! Points have an id, one or more vectors and a payload.
//! Points are usually upserted in bulk, but for this example we'll add a
//! single point:
59//! ```
60//!# use arc_vector_rust::{prelude::*, arc_vector::PointId};
61//!# async fn do_upsert(arc_vector_client: &ArcVectorClient)
62//!# -> Result<(), Box<dyn std::error::Error>> {
63//! let point = PointStruct {
64//!     id: Some(PointId::from(42)), // unique u64 or String
65//!     vectors: Some(vec![0.0_f32; 512].into()),
66//!     payload: std::collections::HashMap::from([
67//!         ("great".into(), Value::from(true)),
68//!         ("level".into(), Value::from(9000)),
69//!         ("text".into(), Value::from("Hi ArcVector!")),
70//!         ("list".into(), Value::from(vec![1.234, 0.815])),
71//!     ]),
72//! };
73//!
74//! let response = arc_vector_client
75//!     .upsert_points("my_collection", vec![point], None)
76//!     .await?;
77//!# Ok(())
78//!# }
79//! ```
80//!
81//! Finally, we can retrieve points in various ways, the canonical one being
82//! a plain similarity search:
83//! ```
84//!# use arc_vector_rust::prelude::*;
85//!# async fn search(arc_vector_client: &ArcVectorClient)
86//!# -> Result<(), Box<dyn std::error::Error>> {
87//! let response = arc_vector_client
88//!     .search_points(&SearchPoints {
89//!         collection_name: "my_collection".to_string(),
90//!         vector: vec![0.0_f32; 512],
91//!         limit: 4,
92//!         with_payload: Some(true.into()),
93//!         ..Default::default()
94//!     })
95//!     .await?;
96//!# Ok(())
97//!# }
98//! ```
99//!
//! You can also add a `filter: Some(filter)` field to the
//! [`SearchPoints`](crate::arc_vector::SearchPoints) argument to filter the
//! result. See the [`Filter`](crate::arc_vector::Filter) documentation for
//! details.
104
105mod channel_pool;
106pub mod client;
107pub mod prelude;
108// Do not lint/fmt code that is generated by tonic
109#[allow(clippy::all)]
110#[rustfmt::skip]
111pub mod arc_vector;
112pub mod filters;
113#[cfg(feature = "serde")]
114pub mod serde;
115
116use arc_vector::{value::Kind::*, ListValue, RetrievedPoint, ScoredPoint, Struct, Value};
117
118use std::error::Error;
119use std::fmt::{Debug, Display, Formatter};
120
121static NULL_VALUE: Value = Value {
122    kind: Some(NullValue(0)),
123};
124
125macro_rules! get_payload {
126    ($ty:ty) => {
127        impl $ty {
128            /// get a payload value for the specified key. If the key is not present,
129            /// this will return a null value.
130            ///
131            /// # Examples:
132            /// ```
133            #[doc = concat!("use arc_vector_rust::arc_vector::", stringify!($ty), ";")]
134            #[doc = concat!("let point = ", stringify!($ty), "::default();")]
135            /// assert!(point.get("not_present").is_null());
136            /// ````
137            pub fn get(&self, key: &str) -> &Value {
138                self.payload.get(key).unwrap_or(&NULL_VALUE)
139            }
140        }
141    };
142}
143
144get_payload!(RetrievedPoint);
145get_payload!(ScoredPoint);
146
// Generates helper methods on `Value` for one `value::Kind` variant:
//
// * `extract!(Variant, is_fn)`: only the `is_fn(&self) -> bool` predicate.
// * `extract!(Variant, is_fn, as_fn, T)`: the predicate plus
//   `as_fn(&self) -> Option<T>`, which copies the payload out (so `T` must be
//   `Copy`).
// * `extract!(Variant, is_fn, as_fn, ref T)`: the predicate plus
//   `as_fn(&self) -> Option<&T>`, which borrows the payload. Deref coercion
//   lets e.g. a `ListValue` payload be returned as `&[Value]`.
macro_rules! extract {
    ($kind:ident, $check:ident) => {
        /// Returns `true` if this value holds a
        #[doc = stringify!($kind)]
        pub fn $check(&self) -> bool {
            matches!(&self.kind, Some($kind(_)))
        }
    };
    ($kind:ident, $check:ident, $extract:ident, $ty:ty) => {
        extract!($kind, $check);

        /// Returns a copy of the payload if this value holds a
        #[doc = stringify!($kind)]
        pub fn $extract(&self) -> Option<$ty> {
            match &self.kind {
                Some($kind(inner)) => Some(*inner),
                _ => None,
            }
        }
    };
    ($kind:ident, $check:ident, $extract:ident, ref $ty:ty) => {
        extract!($kind, $check);

        /// Borrows the payload if this value holds a
        #[doc = stringify!($kind)]
        pub fn $extract(&self) -> Option<&$ty> {
            match &self.kind {
                Some($kind(inner)) => Some(inner),
                _ => None,
            }
        }
    };
}
182
183impl Value {
184    extract!(NullValue, is_null);
185    extract!(BoolValue, is_bool, as_bool, bool);
186    extract!(IntegerValue, is_integer, as_integer, i64);
187    extract!(DoubleValue, is_double, as_double, f64);
188    extract!(StringValue, is_str, as_str, ref String);
189    extract!(ListValue, is_list, as_list, ref [Value]);
190    extract!(StructValue, is_struct, as_struct, ref Struct);
191
192    #[cfg(feature = "serde")]
193    /// convert this into a `serde_json::Value`
194    ///
195    /// # Examples:
196    ///
197    /// ```
198    /// use serde_json::json;
199    /// use arc_vector_rust::prelude::*;
200    /// use arc_vector_rust::arc_vector::{value::Kind::*, Struct};
201    /// let value = Value { kind: Some(StructValue(Struct {
202    ///     fields: [
203    ///         ("text".into(), Value { kind: Some(StringValue("Hi ArcVector!".into())) }),
204    ///         ("int".into(), Value { kind: Some(IntegerValue(42))}),
205    ///     ].into()
206    /// }))};
207    /// assert_eq!(value.into_json(), json!({
208    ///    "text": "Hi ArcVector!",
209    ///    "int": 42
210    /// }));
211    /// ```
212    pub fn into_json(self) -> serde_json::Value {
213        use serde_json::Value as JsonValue;
214        match self.kind {
215            Some(BoolValue(b)) => JsonValue::Bool(b),
216            Some(IntegerValue(i)) => JsonValue::from(i),
217            Some(DoubleValue(d)) => JsonValue::from(d),
218            Some(StringValue(s)) => JsonValue::String(s),
219            Some(ListValue(vs)) => vs.into_iter().map(Value::into_json).collect(),
220            Some(StructValue(s)) => s
221                .fields
222                .into_iter()
223                .map(|(k, v)| (k, v.into_json()))
224                .collect(),
225            Some(NullValue(_)) | None => JsonValue::Null,
226        }
227    }
228}
229
/// Converts a [`Value`] into the equivalent `serde_json::Value`; this is a
/// thin wrapper around [`Value::into_json`].
#[cfg(feature = "serde")]
impl From<Value> for serde_json::Value {
    fn from(source: Value) -> Self {
        Value::into_json(source)
    }
}
236
237impl Display for Value {
238    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
239        match &self.kind {
240            Some(BoolValue(b)) => write!(f, "{}", b),
241            Some(IntegerValue(i)) => write!(f, "{}", i),
242            Some(DoubleValue(v)) => write!(f, "{}", v),
243            Some(StringValue(s)) => write!(f, "{:?}", s),
244            Some(ListValue(vs)) => {
245                let mut i = vs.values.iter();
246                write!(f, "[")?;
247                if let Some(first) = i.next() {
248                    write!(f, "{}", first)?;
249                    for v in i {
250                        write!(f, ",{}", v)?;
251                    }
252                }
253                write!(f, "]")
254            }
255            Some(StructValue(s)) => {
256                let mut i = s.fields.iter();
257                write!(f, "{{")?;
258                if let Some((key, value)) = i.next() {
259                    write!(f, "{:?}:{}", key, value)?;
260                    for (key, value) in i {
261                        write!(f, ",{:?}:{}", key, value)?;
262                    }
263                }
264                write!(f, "}}")
265            }
266            _ => write!(f, "null"),
267        }
268    }
269}
270
pub mod error {
    use std::marker::PhantomData;

    /// Error returned when a [`Value`](crate::arc_vector::Value) does not hold
    /// the kind a conversion asked for, e.g. calling `String::try_from(v)` on an
    /// integer value. The type parameter names the expected target type.
    pub struct NotA<T> {
        marker: PhantomData<T>,
    }

    impl<T> Default for NotA<T> {
        fn default() -> Self {
            Self { marker: PhantomData }
        }
    }
}
288
289use error::NotA;
290
// Implements `Display`, `Debug` and `std::error::Error` for `NotA<$ty>`.
// Both formatting traits render the same "not a <type>" message.
macro_rules! not_a {
    ($ty:ty) => {
        impl Display for NotA<$ty> {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                f.write_str(concat!("not a ", stringify!($ty)))
            }
        }

        impl Debug for NotA<$ty> {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                Display::fmt(self, f)
            }
        }

        impl Error for NotA<$ty> {}
    };
}
308
// Implements `TryFrom<Value>` for the scalar type `$ty`, which is carried by the
// `$key` variant of `value::Kind`. It also emits the `NotA<$ty>` error impls
// through `not_a!`. Any other kind (or a missing kind) yields `NotA<$ty>`.
macro_rules! impl_try_from {
    ($ty:ty, $key:ident) => {
        not_a!($ty);

        impl std::convert::TryFrom<Value> for $ty {
            type Error = NotA<$ty>;

            fn try_from(value: Value) -> Result<Self, Self::Error> {
                match value.kind {
                    Some($key(inner)) => Ok(inner),
                    _ => Err(NotA::default()),
                }
            }
        }
    };
}
326
327impl_try_from!(bool, BoolValue);
328impl_try_from!(i64, IntegerValue);
329impl_try_from!(f64, DoubleValue);
330impl_try_from!(String, StringValue);
331
332not_a!(ListValue);
333not_a!(Struct);
334
335impl Value {
336    /// try to get an iterator over the items of the contained list value, if any
337    pub fn iter_list(&self) -> Result<impl Iterator<Item = &Value>, NotA<ListValue>> {
338        if let Some(ListValue(values)) = &self.kind {
339            Ok(values.iter())
340        } else {
341            Err(NotA::default())
342        }
343    }
344
345    /// try to get a field from the struct if this value contains one
346    pub fn get_struct(&self, key: &str) -> Result<&Value, NotA<Struct>> {
347        if let Some(StructValue(Struct { fields })) = &self.kind {
348            Ok(fields.get(key).unwrap_or(&NULL_VALUE))
349        } else {
350            Err(NotA::default())
351        }
352    }
353}
354
355impl std::ops::Deref for ListValue {
356    type Target = [Value];
357
358    fn deref(&self) -> &[Value] {
359        &self.values
360    }
361}
362
363impl IntoIterator for ListValue {
364    type Item = Value;
365
366    type IntoIter = std::vec::IntoIter<Value>;
367
368    fn into_iter(self) -> Self::IntoIter {
369        self.values.into_iter()
370    }
371}
372
373impl ListValue {
374    pub fn iter(&self) -> std::slice::Iter<'_, Value> {
375        self.values.iter()
376    }
377}
378
#[cfg(test)]
mod tests {
    use crate::arc_vector::value::Kind::*;
    use crate::arc_vector::vectors_config::Config;
    use crate::arc_vector::{
        CreateFieldIndexCollection, FieldType, ListValue, Struct, Value, VectorParams,
        VectorsConfig,
    };
    use crate::prelude::*;
    use std::collections::HashMap;

    /// The `Display` output of a struct value must contain every field in the
    /// compact, JSON-like notation. Only the presence of each `key:value`
    /// fragment is asserted, not the order of the fields.
    #[test]
    fn display() {
        let entries = [
            ("text", StringValue("Hi ArcVector!".into())),
            ("int", IntegerValue(42)),
            ("float", DoubleValue(1.23)),
            (
                "list",
                ListValue(ListValue {
                    values: vec![Value {
                        kind: Some(NullValue(0)),
                    }],
                }),
            ),
            (
                "struct",
                StructValue(Struct {
                    fields: [(
                        "bool".into(),
                        Value {
                            kind: Some(BoolValue(true)),
                        },
                    )]
                    .into(),
                }),
            ),
        ];
        let value = Value {
            kind: Some(StructValue(Struct {
                fields: entries
                    .into_iter()
                    .map(|(name, kind)| (name.into(), Value { kind: Some(kind) }))
                    .collect(),
            })),
        };
        let rendered = value.to_string();

        let expected = [
            "\"float\":1.23",
            "\"list\":[null]",
            "\"struct\":{\"bool\":true}",
            "\"int\":42",
            "\"text\":\"Hi ArcVector!\"",
        ];
        for fragment in expected {
            assert!(
                rendered.contains(fragment),
                "`{}` not found in `{}`",
                fragment,
                rendered
            );
        }
    }

    /// Integration test against a server listening on `http://localhost:6334`.
    ///
    /// Runs the main client operations in sequence on a collection named
    /// `test`: delete/create, upsert, search, payload update and field
    /// deletion, point retrieval and deletion, payload index creation through
    /// the raw points client, and snapshot creation (plus download when the
    /// `download_snapshots` feature is enabled).
    #[tokio::test]
    async fn test_arc_vector_queries() -> anyhow::Result<()> {
        let client = ArcVectorClient::new(Some(ArcVectorClientConfig::from_url(
            "http://localhost:6334",
        )))?;

        let health = client.health_check().await?;
        println!("{:?}", health);

        let collections = client.list_collections().await?;
        println!("{:?}", collections);

        let collection_name = "test";
        client.delete_collection(collection_name).await?;

        let vectors_config = VectorsConfig {
            config: Some(Config::Params(VectorParams {
                size: 10,
                distance: Distance::Cosine.into(),
                hnsw_config: None,
                quantization_config: None,
                on_disk: None,
            })),
        };
        client
            .create_collection(&CreateCollection {
                collection_name: collection_name.into(),
                vectors_config: Some(vectors_config),
                ..Default::default()
            })
            .await?;

        let info = client.collection_info(collection_name).await?;
        println!("{:#?}", info);

        let mut nested_payload = Payload::new();
        nested_payload.insert("foo", "Not bar");

        let payload: Payload = vec![
            ("foo", "Bar".into()),
            ("bar", 12.into()),
            ("sub_payload", nested_payload.into()),
        ]
        .into_iter()
        .collect::<HashMap<_, Value>>()
        .into();

        let new_points = vec![PointStruct::new(0, vec![12.; 10], payload)];
        client
            .upsert_points_blocking(collection_name, new_points, None)
            .await?;

        let found = client
            .search_points(&SearchPoints {
                collection_name: collection_name.into(),
                vector: vec![11.; 10],
                filter: None,
                limit: 10,
                with_payload: Some(true.into()),
                params: None,
                score_threshold: None,
                offset: None,
                vector_name: None,
                with_vectors: None,
                read_consistency: None,
            })
            .await?;
        eprintln!("search_result = {:#?}", found);

        // Set a new value for `foo` on point 0.
        let replacement: Payload = vec![("foo", "BAZ".into())]
            .into_iter()
            .collect::<HashMap<_, Value>>()
            .into();
        client
            .set_payload(collection_name, &vec![0.into()].into(), replacement, None)
            .await?;

        // Remove the nested `sub_payload` field from point 0.
        client
            .delete_payload_blocking(
                collection_name,
                &vec![0.into()].into(),
                vec!["sub_payload".to_string()],
                None,
            )
            .await?;

        // Read point 0 back and check that the payload edits took effect.
        let fetched = client
            .get_points(collection_name, &[0.into()], Some(true), Some(true), None)
            .await?;
        assert_eq!(fetched.result.len(), 1);
        let fetched_point = &fetched.result[0];
        assert!(fetched_point.payload.contains_key("foo"));
        assert!(!fetched_point.payload.contains_key("sub_payload"));

        client
            .delete_points(collection_name, &vec![0.into()].into(), None)
            .await?;

        // Access the raw points API through the client.
        client
            .with_points_client(|mut points_client| async move {
                points_client
                    .create_field_index(CreateFieldIndexCollection {
                        collection_name: collection_name.to_string(),
                        wait: None,
                        field_name: "foo".to_string(),
                        field_type: Some(FieldType::Keyword as i32),
                        field_index_params: None,
                        ordering: None,
                    })
                    .await
            })
            .await?;

        client.create_snapshot(collection_name).await?;
        #[cfg(feature = "download_snapshots")]
        client
            .download_snapshot("test.tar", collection_name, None, None)
            .await?;

        Ok(())
    }
}