Skip to main content

qdrant_client/grpc_conversions/extensions.rs

1use std::fmt::{Display, Formatter};
2use std::hash::{Hash, Hasher};
3
4#[cfg(feature = "uuid")]
5use uuid::Uuid;
6
7use crate::payload::Payload;
8#[cfg(feature = "uuid")]
9use crate::qdrant::point_id::PointIdOptions;
10use crate::qdrant::value::Kind;
11use crate::qdrant::{
12    HardwareUsage, InferenceUsage, ListValue, ModelUsage, PointId, PointStruct, RetrievedPoint,
13    ScoredPoint, Struct, Usage, Value, Vectors,
14};
15
/// Shared `null` [`Value`].
///
/// The `get` accessors on [`RetrievedPoint`] and [`ScoredPoint`] return a reference to this
/// static when a payload key is missing, so they can always hand back a `&Value`.
static NULL_VALUE: Value = Value {
    kind: Some(Kind::NullValue(0)),
};
20
21impl PointStruct {
22    pub fn new(
23        id: impl Into<PointId>,
24        vectors: impl Into<Vectors>,
25        payload: impl Into<Payload>,
26    ) -> Self {
27        Self {
28            id: Some(id.into()),
29            payload: payload.into().into(),
30            vectors: Some(vectors.into()),
31        }
32    }
33}
34
35impl RetrievedPoint {
36    /// Get a payload value for the specified key. If the key is not present,
37    /// this will return a null value.
38    ///
39    /// # Examples:
40    ///
41    /// ```
42    /// use qdrant_client::qdrant::RetrievedPoint;
43    /// let point = RetrievedPoint::default();
44    /// assert!(point.get("not_present").is_null());
45    /// ````
46    pub fn get(&self, key: &str) -> &Value {
47        self.try_get(key).unwrap_or(&NULL_VALUE)
48    }
49
50    /// Try to get a payload value for the specified key. If the key is not present,
51    /// this will return `None`.
52    ///
53    /// # Examples:
54    ///
55    /// ```
56    /// use qdrant_client::qdrant::RetrievedPoint;
57    /// let point = RetrievedPoint::default();
58    /// assert_eq!(point.try_get("not_present"), None);
59    /// ````
60    pub fn try_get(&self, key: &str) -> Option<&Value> {
61        self.payload.get(key)
62    }
63}
64
65impl ScoredPoint {
66    /// Get a payload value for the specified key. If the key is not present,
67    /// this will return a null value.
68    ///
69    /// # Examples:
70    ///
71    /// ```
72    /// use qdrant_client::qdrant::ScoredPoint;
73    /// let point = ScoredPoint::default();
74    /// assert!(point.get("not_present").is_null());
75    /// ````
76    pub fn get(&self, key: &str) -> &Value {
77        self.try_get(key).unwrap_or(&NULL_VALUE)
78    }
79
80    /// Get a payload value for the specified key. If the key is not present,
81    /// this will return `None`.
82    ///
83    /// # Examples:
84    ///
85    /// ```
86    /// use qdrant_client::qdrant::ScoredPoint;
87    /// let point = ScoredPoint::default();
88    /// assert_eq!(point.try_get("not_present"), None);
89    /// ````
90    pub fn try_get(&self, key: &str) -> Option<&Value> {
91        self.payload.get(key)
92    }
93}
94
// Generates accessor methods on `Value` for one variant of `Kind`.
//
// Arms (macro_rules tries them in order):
// - `(Variant, is_fn)`: only the predicate `pub fn is_fn(&self) -> bool`.
// - `(Variant, is_fn, as_fn, Ty)`: the predicate plus `pub fn as_fn(&self) -> Option<Ty>`,
//   which copies the payload out of `self.kind` (so the payload type must be `Copy`).
// - `(Variant, is_fn, as_fn, ref Ty)`: the predicate plus `pub fn as_fn(&self) -> Option<&Ty>`,
//   which borrows the payload; the borrowed payload must coerce to `&Ty`
//   (e.g. `&ListValue` -> `&[Value]` through its `Deref` impl).
//
// `Variant` is a single identifier, so callers must have the `Kind` variants in scope
// by their bare names; `Kind::Variant` cannot be passed.
macro_rules! extract {
    ($kind:ident, $check:ident) => {
        /// Check if this value is a
        #[doc = stringify!([$kind])]
        pub fn $check(&self) -> bool {
            // `_` does not bind, so nothing is moved out of `self.kind`.
            matches!(self.kind, Some($kind(_)))
        }
    };
    ($kind:ident, $check:ident, $extract:ident, $ty:ty) => {
        // Reuse the predicate arm, then add the by-value accessor.
        extract!($kind, $check);

        /// Get this value as
        #[doc = stringify!([$ty])]
        ///
        /// Returns `None` if this value is not a
        #[doc = stringify!([$kind].)]
        pub fn $extract(&self) -> Option<$ty> {
            // Binding `v` by value from behind `&self` copies it; this only
            // compiles for `Copy` payloads.
            if let Some($kind(v)) = self.kind {
                Some(v)
            } else {
                None
            }
        }
    };
    ($kind:ident, $check:ident, $extract:ident, ref $ty:ty) => {
        // Reuse the predicate arm, then add the borrowing accessor.
        extract!($kind, $check);

        /// Get this value as
        #[doc = stringify!([$ty])]
        ///
        /// Returns `None` if this value is not a
        #[doc = stringify!([$kind].)]
        pub fn $extract(&self) -> Option<&$ty> {
            // Matching on `&self.kind` binds `v` by reference, so no copy/move occurs.
            if let Some($kind(v)) = &self.kind {
                Some(v)
            } else {
                None
            }
        }
    };
}
136
137// Separate module to not import all enum kinds of `Kind` directly as this conflicts with other types.
138// The macro extract!() however is built to take enum kinds directly and passing Kind::<kind> is not possible.
139mod value_extract_impl {
140    use crate::qdrant::value::Kind::*;
141    use crate::qdrant::{Struct, Value};
142    impl Value {
143        extract!(NullValue, is_null);
144        extract!(BoolValue, is_bool, as_bool, bool);
145        extract!(IntegerValue, is_integer, as_integer, i64);
146        extract!(DoubleValue, is_double, as_double, f64);
147        extract!(StringValue, is_str, as_str, ref String);
148        extract!(ListValue, is_list, as_list, ref [Value]);
149        extract!(StructValue, is_struct, as_struct, ref Struct);
150    }
151}
152
153impl Value {
154    #[cfg(feature = "serde")]
155    /// Convert this into a [`serde_json::Value`]
156    ///
157    /// # Examples:
158    ///
159    /// ```
160    /// use serde_json::json;
161    /// use qdrant_client::qdrant::{value::Kind::*, Struct, Value};
162    /// let value = Value { kind: Some(StructValue(Struct {
163    ///     fields: [
164    ///         ("text".into(), Value { kind: Some(StringValue("Hi Qdrant!".into())) }),
165    ///         ("int".into(), Value { kind: Some(IntegerValue(42))}),
166    ///     ].into()
167    /// }))};
168    /// assert_eq!(value.into_json(), json!({
169    ///    "text": "Hi Qdrant!",
170    ///    "int": 42
171    /// }));
172    /// ```
173    pub fn into_json(self) -> serde_json::Value {
174        use serde_json::Value as JsonValue;
175        match self.kind {
176            Some(Kind::BoolValue(b)) => JsonValue::Bool(b),
177            Some(Kind::IntegerValue(i)) => JsonValue::from(i),
178            Some(Kind::DoubleValue(d)) => JsonValue::from(d),
179            Some(Kind::StringValue(s)) => JsonValue::String(s),
180            Some(Kind::ListValue(vs)) => vs.into_iter().map(Value::into_json).collect(),
181            Some(Kind::StructValue(s)) => s
182                .fields
183                .into_iter()
184                .map(|(k, v)| (k, v.into_json()))
185                .collect(),
186            Some(Kind::NullValue(_)) | None => JsonValue::Null,
187        }
188    }
189}
190
#[cfg(feature = "serde")]
impl From<Value> for serde_json::Value {
    /// Same as calling [`Value::into_json`].
    fn from(value: Value) -> Self {
        Value::into_json(value)
    }
}
197
198impl Display for Value {
199    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
200        match &self.kind {
201            Some(Kind::BoolValue(b)) => write!(f, "{b}"),
202            Some(Kind::IntegerValue(i)) => write!(f, "{i}"),
203            Some(Kind::DoubleValue(v)) => write!(f, "{v}"),
204            Some(Kind::StringValue(s)) => write!(f, "{s:?}"),
205            Some(Kind::ListValue(vs)) => {
206                let mut i = vs.values.iter();
207                write!(f, "[")?;
208                if let Some(first) = i.next() {
209                    write!(f, "{first}")?;
210                    for v in i {
211                        write!(f, ",{v}")?;
212                    }
213                }
214                write!(f, "]")
215            }
216            Some(Kind::StructValue(s)) => {
217                let mut i = s.fields.iter();
218                write!(f, "{{")?;
219                if let Some((key, value)) = i.next() {
220                    write!(f, "{key:?}:{value}")?;
221                    for (key, value) in i {
222                        write!(f, ",{key:?}:{value}")?;
223                    }
224                }
225                write!(f, "}}")
226            }
227            _ => write!(f, "null"),
228        }
229    }
230}
231
232impl Value {
233    /// Try to get an iterator over the items of the contained list value
234    ///
235    /// Returns `None` if this is not a list.
236    pub fn try_list_iter(&self) -> Option<impl Iterator<Item = &Value>> {
237        if let Some(Kind::ListValue(values)) = &self.kind {
238            Some(values.iter())
239        } else {
240            None
241        }
242    }
243
244    /// Get a value from a struct field
245    ///
246    /// Returns `None` if this is not a struct type or if the field is not present.
247    pub fn get_value(&self, key: &str) -> Option<&Value> {
248        if let Some(Kind::StructValue(Struct { fields })) = &self.kind {
249            Some(fields.get(key)?)
250        } else {
251            None
252        }
253    }
254}
255
256impl std::ops::Deref for ListValue {
257    type Target = [Value];
258
259    fn deref(&self) -> &[Value] {
260        &self.values
261    }
262}
263
264impl IntoIterator for ListValue {
265    type Item = Value;
266
267    type IntoIter = std::vec::IntoIter<Value>;
268
269    fn into_iter(self) -> Self::IntoIter {
270        self.values.into_iter()
271    }
272}
273
274impl ListValue {
275    pub fn iter(&self) -> std::slice::Iter<'_, Value> {
276        self.values.iter()
277    }
278}
279
#[cfg(feature = "uuid")]
impl From<Uuid> for PointId {
    /// Wrap a UUID as a point id by delegating to `From<Uuid> for PointIdOptions`.
    fn from(uuid: Uuid) -> Self {
        let options: PointIdOptions = uuid.into();
        Self {
            point_id_options: Some(options),
        }
    }
}
288
#[cfg(feature = "uuid")]
impl From<Uuid> for PointIdOptions {
    /// Store the UUID in its string form (as produced by `Uuid::to_string`).
    fn from(uuid: Uuid) -> Self {
        let text = uuid.to_string();
        Self::Uuid(text)
    }
}
295
296impl Hash for PointId {
297    fn hash<H: Hasher>(&self, state: &mut H) {
298        use crate::qdrant::point_id::PointIdOptions::{Num, Uuid};
299        match &self.point_id_options {
300            Some(Num(u)) => state.write_u64(*u),
301            Some(Uuid(s)) => s.hash(state),
302            None => {}
303        }
304    }
305}
306
307impl Hash for ScoredPoint {
308    fn hash<H: Hasher>(&self, state: &mut H) {
309        self.id.hash(state)
310    }
311}
312
313impl Hash for RetrievedPoint {
314    fn hash<H: Hasher>(&self, state: &mut H) {
315        self.id.hash(state)
316    }
317}
318
319impl Usage {
320    pub(crate) fn aggregate_opts(this: Option<Self>, other: Option<Self>) -> Option<Self> {
321        match (this, other) {
322            (Some(this), Some(other)) => Some(this.aggregate(other)),
323            (Some(this), None) => Some(this),
324            (None, Some(other)) => Some(other),
325            (None, None) => None,
326        }
327    }
328
329    pub(crate) fn aggregate(self, other: Self) -> Self {
330        Self {
331            hardware: HardwareUsage::aggregate_opts(self.hardware, other.hardware),
332            inference: InferenceUsage::aggregate_opts(self.inference, other.inference),
333        }
334    }
335}
336
337impl HardwareUsage {
338    pub(crate) fn aggregate_opts(this: Option<Self>, other: Option<Self>) -> Option<Self> {
339        match (this, other) {
340            (Some(this), Some(other)) => Some(this.aggregate(other)),
341            (Some(this), None) => Some(this),
342            (None, Some(other)) => Some(other),
343            (None, None) => None,
344        }
345    }
346
347    pub(crate) fn aggregate(self, other: Self) -> Self {
348        let Self {
349            cpu,
350            payload_io_read,
351            payload_io_write,
352            payload_index_io_read,
353            payload_index_io_write,
354            vector_io_read,
355            vector_io_write,
356        } = other;
357
358        Self {
359            cpu: self.cpu + cpu,
360            payload_io_read: self.payload_io_read + payload_io_read,
361            payload_io_write: self.payload_io_write + payload_io_write,
362            payload_index_io_read: self.payload_index_io_read + payload_index_io_read,
363            payload_index_io_write: self.payload_index_io_write + payload_index_io_write,
364            vector_io_read: self.vector_io_read + vector_io_read,
365            vector_io_write: self.vector_io_write + vector_io_write,
366        }
367    }
368}
369
370impl InferenceUsage {
371    pub(crate) fn aggregate_opts(this: Option<Self>, other: Option<Self>) -> Option<Self> {
372        match (this, other) {
373            (Some(this), Some(other)) => Some(this.aggregate(other)),
374            (Some(this), None) => Some(this),
375            (None, Some(other)) => Some(other),
376            (None, None) => None,
377        }
378    }
379
380    pub(crate) fn aggregate(self, other: Self) -> Self {
381        let mut models = self.models;
382        for (model_name, other_usage) in other.models {
383            models
384                .entry(model_name)
385                .and_modify(|usage| {
386                    *usage = usage.aggregate(other_usage);
387                })
388                .or_insert(other_usage);
389        }
390
391        Self { models }
392    }
393}
394
395impl ModelUsage {
396    pub(crate) fn aggregate(self, other: Self) -> Self {
397        Self {
398            tokens: self.tokens + other.tokens,
399        }
400    }
401}
402
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    #[test]
    fn test_inference_usage_aggregation() {
        let left = InferenceUsage {
            models: HashMap::from([
                ("model_a".to_string(), ModelUsage { tokens: 100 }),
                ("model_b".to_string(), ModelUsage { tokens: 200 }),
            ]),
        };
        let right = InferenceUsage {
            models: HashMap::from([
                ("model_a".to_string(), ModelUsage { tokens: 50 }),
                ("model_c".to_string(), ModelUsage { tokens: 300 }),
            ]),
        };

        let merged = left.aggregate(right);

        // A model present in both reports has its tokens summed.
        assert_eq!(merged.models["model_a"].tokens, 150);
        // A model only in the left report is kept as-is.
        assert_eq!(merged.models["model_b"].tokens, 200);
        // A model only in the right report is carried over.
        assert_eq!(merged.models["model_c"].tokens, 300);
        // No other models appear.
        assert_eq!(merged.models.len(), 3);
    }
}