Skip to main content

firestore/firestore_serde/
vector_serializers.rs

1use crate::errors::FirestoreError;
2use crate::firestore_serde::serializer::FirestoreValueSerializer;
3use crate::FirestoreValue;
4use serde::de::{MapAccess, Visitor};
5use serde::{Deserializer, Serialize};
6
7pub(crate) const FIRESTORE_VECTOR_TYPE_TAG_TYPE: &str = "FirestoreVector";
8
9#[derive(Serialize, Clone, Debug, PartialEq, PartialOrd, Default)]
10pub struct FirestoreVector(pub Vec<f64>);
11
12impl FirestoreVector {
13    pub fn new(vec: Vec<f64>) -> Self {
14        FirestoreVector(vec)
15    }
16
17    pub fn into_vec(self) -> Vec<f64> {
18        self.0
19    }
20
21    pub fn as_vec(&self) -> &Vec<f64> {
22        &self.0
23    }
24}
25
26impl From<FirestoreVector> for Vec<f64> {
27    fn from(val: FirestoreVector) -> Self {
28        val.into_vec()
29    }
30}
31
32impl<I> From<I> for FirestoreVector
33where
34    I: IntoIterator<Item = f64>,
35{
36    fn from(vec: I) -> Self {
37        FirestoreVector(vec.into_iter().collect())
38    }
39}
40
41pub fn serialize_vector_for_firestore<T: ?Sized + Serialize>(
42    firestore_value_serializer: FirestoreValueSerializer,
43    value: &T,
44) -> Result<FirestoreValue, FirestoreError> {
45    let value_with_array = value.serialize(firestore_value_serializer)?;
46
47    Ok(FirestoreValue::from(
48        gcloud_sdk::google::firestore::v1::Value {
49            value_type: Some(gcloud_sdk::google::firestore::v1::value::ValueType::MapValue(
50                gcloud_sdk::google::firestore::v1::MapValue {
51                    fields: vec![
52                        (
53                            "__type__".to_string(),
54                            gcloud_sdk::google::firestore::v1::Value {
55                                value_type: Some(gcloud_sdk::google::firestore::v1::value::ValueType::StringValue(
56                                    "__vector__".to_string()
57                                )),
58                            }
59                        ),
60                        (
61                            "value".to_string(),
62                            value_with_array.value
63                        )].into_iter().collect()
64                }
65            ))
66        }),
67    )
68}
69
70struct FirestoreVectorVisitor;
71
72impl<'de> Visitor<'de> for FirestoreVectorVisitor {
73    type Value = FirestoreVector;
74
75    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
76        formatter.write_str("a FirestoreVector")
77    }
78
79    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
80    where
81        A: serde::de::SeqAccess<'de>,
82    {
83        let mut vec = Vec::new();
84
85        while let Some(value) = seq.next_element()? {
86            vec.push(value);
87        }
88
89        Ok(FirestoreVector(vec))
90    }
91
92    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
93    where
94        A: MapAccess<'de>,
95    {
96        while let Some(field) = map.next_key::<String>()? {
97            match field.as_str() {
98                "__type__" => {
99                    let value = map.next_value::<String>()?;
100                    if value != "__vector__" {
101                        return Err(serde::de::Error::custom(
102                            "Expected __vector__  for FirestoreVector",
103                        ));
104                    }
105                }
106                "value" => {
107                    let value = map.next_value::<Vec<f64>>()?;
108                    return Ok(FirestoreVector(value));
109                }
110                _ => {
111                    return Err(serde::de::Error::custom(
112                        "Unknown field for FirestoreVector",
113                    ));
114                }
115            }
116        }
117        Err(serde::de::Error::custom(
118            "Unknown structure for FirestoreVector",
119        ))
120    }
121}
122
123impl<'de> serde::Deserialize<'de> for FirestoreVector {
124    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
125    where
126        D: Deserializer<'de>,
127    {
128        deserializer.deserialize_any(FirestoreVectorVisitor)
129    }
130}