prusto_rs/types/
data_set.rs1use std::fmt;
2use std::marker::PhantomData;
3
4use iterable::Iterable;
5use serde::de::{self, Deserializer, MapAccess, Visitor};
6use serde::ser::{SerializeStruct, Serializer};
7use serde::{Deserialize, Serialize};
8
9use super::util::SerializeIterator;
10use super::{Context, Error, Presto, PrestoTy, VecSeed};
11use crate::models::Column;
12use crate::Row;
13
14#[derive(Debug)]
15pub struct DataSet<T: Presto> {
16 types: Vec<(String, PrestoTy)>,
17 data: Vec<T>,
18}
19
20impl<T: Presto> DataSet<T> {
21 pub fn new(data: Vec<T>) -> Result<Self, Error> {
22 let types = match T::ty() {
23 PrestoTy::Row(r) => {
24 if r.is_empty() {
25 return Err(Error::EmptyInPrestoRow);
26 } else {
27 r
28 }
29 }
30 _ => return Err(Error::NonePrestoRow),
31 };
32
33 Ok(DataSet { types, data })
34 }
35
36 pub fn split(self) -> (Vec<(String, PrestoTy)>, Vec<T>) {
37 (self.types, self.data)
38 }
39
40 pub fn into_vec(self) -> Vec<T> {
41 self.data
42 }
43
44 pub fn is_empty(&self) -> bool {
45 self.data.is_empty()
46 }
47
48 pub fn len(&self) -> usize {
49 self.data.len()
50 }
51
52 pub fn as_slice(&self) -> &[T] {
53 self.data.as_slice()
54 }
55
56 pub fn merge(&mut self, other: DataSet<T>) {
57 self.data.extend(other.data)
58 }
59}
60
61impl DataSet<Row> {
62 pub fn new_row(types: Vec<(String, PrestoTy)>, data: Vec<Row>) -> Result<Self, Error> {
63 if types.is_empty() {
64 return Err(Error::EmptyInPrestoRow);
65 }
66 Ok(DataSet { types, data })
67 }
68}
69
70impl<T: Presto + Clone> Clone for DataSet<T> {
71 fn clone(&self) -> Self {
72 DataSet {
73 types: self.types.clone(),
74 data: self.data.clone(),
75 }
76 }
77}
78
79impl<T: Presto> Serialize for DataSet<T> {
83 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
84 where
85 S: Serializer,
86 {
87 let mut state = serializer.serialize_struct("DataSet", 2)?;
88
89 let columns = self.types.clone().map(|(name, ty)| Column {
90 name,
91 ty: ty.full_type().into_owned(),
92 type_signature: Some(ty.into_type_signature()),
93 });
94
95 let data = SerializeIterator {
96 iter: self.data.iter().map(|d| d.value()),
97 size: Some(self.data.len()),
98 };
99 state.serialize_field("columns", &columns)?;
100 state.serialize_field("data", &data)?;
101 state.end()
102 }
103}
104
105#[derive(Deserialize)]
109#[serde(field_identifier, rename_all = "lowercase")]
110enum Field {
111 Columns,
112 Data,
113}
114
115const FIELDS: &[&str] = &["columns", "data"];
116
117impl<'de, T: Presto> Deserialize<'de> for DataSet<T> {
118 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
119 where
120 D: Deserializer<'de>,
121 {
122 struct DataSetVisitor<T: Presto>(PhantomData<T>);
123
124 impl<'de, T: Presto> Visitor<'de> for DataSetVisitor<T> {
125 type Value = DataSet<T>;
126 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
127 formatter.write_str("struct DataSet")
128 }
129
130 fn visit_map<V>(self, mut map: V) -> Result<DataSet<T>, V::Error>
131 where
132 V: MapAccess<'de>,
133 {
134 let types = if let Some(Field::Columns) = map.next_key()? {
135 let columns: Vec<Column> = map.next_value()?;
136 columns.try_map(PrestoTy::from_column).map_err(|e| {
137 de::Error::custom(format!("deserialize presto type failed, reason: {}", e))
138 })?
139 } else {
140 return Err(de::Error::missing_field("columns"));
141 };
142
143 let array_ty = PrestoTy::Array(Box::new(PrestoTy::Row(types.clone())));
144 let ctx = Context::new::<Vec<T>>(&array_ty).map_err(|e| {
145 de::Error::custom(format!("invalid presto type, reason: {}", e))
146 })?;
147 let seed = VecSeed::new(&ctx);
148
149 let data = if let Some(Field::Data) = map.next_key()? {
150 map.next_value_seed(seed)?
151 } else {
152 vec![]
154 };
155
156 match map.next_key::<Field>()? {
157 Some(Field::Columns) => return Err(de::Error::duplicate_field("columns")),
158 Some(Field::Data) => return Err(de::Error::duplicate_field("data")),
159 None => {}
160 }
161
162 if let PrestoTy::Unknown = T::ty() {
163 Ok(DataSet { types, data })
164 } else {
165 DataSet::new(data).map_err(|e| {
166 de::Error::custom(format!("construct data failed, reason: {}", e))
167 })
168 }
169 }
170 }
171
172 deserializer.deserialize_struct("DataSet", FIELDS, DataSetVisitor(PhantomData))
173 }
174}