// prusto_rs/models/column.rs

1use std::borrow::Cow;
2use std::fmt;
3
4use serde::de::{self, MapAccess, Visitor};
5use serde::ser::SerializeStruct;
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7
8use super::RawPrestoTy;
9
10#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
11#[serde(rename_all = "camelCase")]
12pub struct Column {
13    pub name: String,
14    #[serde(rename = "type")]
15    pub ty: String,
16    pub type_signature: Option<TypeSignature>,
17}
18
19#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
20#[serde(rename_all = "camelCase")]
21pub struct TypeSignature {
22    pub raw_type: RawPrestoTy,
23    pub arguments: Vec<ClientTypeSignatureParameter>,
24    #[serde(skip)]
25    type_arguments: (), // deprecated
26    #[serde(skip)]
27    literal_arguments: (), //deprecated
28}
29
30#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
31#[serde(rename_all = "camelCase")]
32pub struct NamedTypeSignature {
33    pub field_name: Option<RowFieldName>,
34    pub type_signature: TypeSignature,
35}
36
37#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
38#[serde(rename_all = "camelCase")]
39pub struct RowFieldName {
40    pub name: String,
41    #[serde(skip)]
42    delimited: (), // deprecated
43}
44
45#[derive(Clone, Debug, Eq, PartialEq)]
46pub enum ClientTypeSignatureParameter {
47    TypeSignature(TypeSignature),
48    NamedTypeSignature(NamedTypeSignature),
49    LongLiteral(u64),
50}
51
52impl Serialize for ClientTypeSignatureParameter {
53    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
54    where
55        S: Serializer,
56    {
57        use ClientTypeSignatureParameter::*;
58        let mut state = serializer.serialize_struct("ClientTypeSignatureParameter", 2)?;
59        match self {
60            TypeSignature(s) => {
61                state.serialize_field("kind", "TYPE")?;
62                state.serialize_field("value", s)?;
63            }
64            NamedTypeSignature(s) => {
65                state.serialize_field("kind", "NAMED_TYPE")?;
66                state.serialize_field("value", s)?;
67            }
68            LongLiteral(s) => {
69                state.serialize_field("kind", "LONG")?;
70                state.serialize_field("value", s)?;
71            }
72        };
73        state.end()
74    }
75}
76
77impl<'de> Deserialize<'de> for ClientTypeSignatureParameter {
78    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79    where
80        D: Deserializer<'de>,
81    {
82        #[derive(Deserialize)]
83        #[serde(field_identifier, rename_all = "lowercase")]
84        enum Field {
85            Kind,
86            Value,
87        }
88
89        struct ParamVisitor;
90
91        impl<'de> Visitor<'de> for ParamVisitor {
92            type Value = ClientTypeSignatureParameter;
93            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
94                formatter.write_str("struct ClientTypeSignatureParameter")
95            }
96
97            fn visit_map<V>(self, mut map: V) -> Result<ClientTypeSignatureParameter, V::Error>
98            where
99                V: MapAccess<'de>,
100            {
101                let kind = if let Some(Field::Kind) = map.next_key()? {
102                    // this is can't be `&str`
103                    // https://github.com/serde-rs/serde/issues/1009
104                    // https://github.com/serde-rs/serde/issues/1413#issuecomment-494892266
105                    map.next_value::<Cow<'_, str>>()?
106                } else {
107                    return Err(de::Error::missing_field("kind"));
108                };
109                if let Some(Field::Value) = map.next_key()? {
110                    match kind.as_ref() {
111                        "TYPE" | "TYPE_SIGNATURE" => {
112                            let v = map.next_value()?;
113                            Ok(ClientTypeSignatureParameter::TypeSignature(v))
114                        }
115                        "NAMED_TYPE" | "NAMED_TYPE_SIGNATURE" => {
116                            let v = map.next_value()?;
117                            Ok(ClientTypeSignatureParameter::NamedTypeSignature(v))
118                        }
119                        "LONG" | "LONG_LITERAL" => {
120                            let v = map.next_value()?;
121                            Ok(ClientTypeSignatureParameter::LongLiteral(v))
122                        }
123                        k => Err(de::Error::custom(format!("unknown kind: {}", k))),
124                    }
125                } else {
126                    Err(de::Error::missing_field("value"))
127                }
128            }
129        }
130
131        const FIELDS: &[&str] = &["kind", "value"];
132        deserializer.deserialize_struct("ClientTypeSignatureParameter", FIELDS, ParamVisitor)
133    }
134}
135
136impl TypeSignature {
137    pub fn new(raw_type: RawPrestoTy, arguments: Vec<ClientTypeSignatureParameter>) -> Self {
138        TypeSignature {
139            raw_type,
140            arguments,
141            type_arguments: (),
142            literal_arguments: (),
143        }
144    }
145}
146
147impl RowFieldName {
148    pub fn new(name: String) -> Self {
149        RowFieldName {
150            name,
151            delimited: (),
152        }
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `json` as a [`TypeSignature`], panicking on malformed input.
    fn parse_sig(json: &str) -> TypeSignature {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn test_sig_varchar_de() {
        let json = r#"
        {
                "rawType": "varchar",
                "typeArguments": [],
                "literalArguments": [],
                "arguments": [
                    {
                        "kind": "LONG",
                        "value": 2147483647
                    }
                ]
        }
        "#;

        let parsed = parse_sig(json);
        let expected = TypeSignature::new(
            RawPrestoTy::VarChar,
            vec![ClientTypeSignatureParameter::LongLiteral(2147483647)],
        );
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_sig_ty_de() {
        let json = r#"
        {
                "rawType": "map",
                "typeArguments": [],
                "literalArguments": [],
                "arguments": [
                    {
                        "kind": "TYPE_SIGNATURE",
                        "value": {
                            "rawType": "varchar",
                            "typeArguments": [],
                            "literalArguments": [],
                            "arguments": [
                                {
                                    "kind": "LONG",
                                    "value": 3
                                }
                            ]
                        }
                    }
                ]
            }
        "#;

        let parsed = parse_sig(json);
        let inner = TypeSignature::new(
            RawPrestoTy::VarChar,
            vec![ClientTypeSignatureParameter::LongLiteral(3)],
        );
        let expected = TypeSignature::new(
            RawPrestoTy::Map,
            vec![ClientTypeSignatureParameter::TypeSignature(inner)],
        );
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_sig_named_ty_de() {
        let json = r#"
        {
                "rawType": "row",
                "typeArguments": [],
                "literalArguments": [],
                "arguments": [
                    {
                        "kind": "NAMED_TYPE_SIGNATURE",
                        "value": {
                            "fieldName": {
                                "name": "y",
                                "delimited": false
                            },
                            "typeSignature": {
                                "rawType": "double",
                                "typeArguments": [],
                                "literalArguments": [],
                                "arguments": []
                            }
                        }
                    }
                ]
            }
        "#;

        let parsed = parse_sig(json);
        let field = NamedTypeSignature {
            field_name: Some(RowFieldName::new("y".to_string())),
            type_signature: TypeSignature::new(RawPrestoTy::Double, vec![]),
        };
        let expected = TypeSignature::new(
            RawPrestoTy::Row,
            vec![ClientTypeSignatureParameter::NamedTypeSignature(field)],
        );
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_sig_param() {
        let s = r#"{"kind":"LONG","value":10}"#;
        let decoded: ClientTypeSignatureParameter = serde_json::from_str(s).unwrap();
        assert_eq!(decoded, ClientTypeSignatureParameter::LongLiteral(10));

        // Serializing and decoding again must yield the same value.
        let encoded = serde_json::to_value(&decoded).unwrap();
        let round_tripped: ClientTypeSignatureParameter = serde_json::from_value(encoded).unwrap();
        assert_eq!(decoded, round_tripped)
    }
}