1use super::{decode_and_convert, decode_multiple_elements, ProvableResultColumn, QueryError};
2use crate::base::{
3 database::{Column, ColumnField, ColumnType, OwnedColumn, OwnedTable, Table},
4 polynomial::compute_evaluation_vector,
5 scalar::{Scalar, ScalarExt},
6};
7use alloc::{vec, vec::Vec};
8use num_traits::Zero;
9use serde::{Deserialize, Serialize};
10
/// A query result whose columns are serialized into one contiguous byte buffer.
///
/// Columns are written back-to-back into `data` in column order (see
/// [`ProvableQueryResult::new`]) and are decoded again by
/// [`ProvableQueryResult::evaluate`] and [`ProvableQueryResult::to_owned_table`],
/// both of which check that the whole buffer is consumed.
///
/// NOTE: field order determines the derived serde encoding; do not reorder fields.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct ProvableQueryResult {
    /// Number of columns encoded in `data`.
    num_columns: u64,
    /// Number of rows in each encoded column.
    pub(crate) table_length: u64,
    /// The encoded columns, concatenated in column order.
    data: Vec<u8>,
}
19
20impl ProvableQueryResult {
22 #[allow(clippy::cast_possible_truncation)]
23 #[must_use]
25 pub fn num_columns(&self) -> usize {
26 self.num_columns as usize
27 }
28 #[cfg(test)]
31 pub fn num_columns_mut(&mut self) -> &mut u64 {
32 &mut self.num_columns
33 }
34
35 #[allow(clippy::cast_possible_truncation)]
36 #[must_use]
38 pub fn table_length(&self) -> usize {
39 self.table_length as usize
40 }
41 #[cfg(test)]
44 pub fn data_mut(&mut self) -> &mut Vec<u8> {
45 &mut self.data
46 }
47 #[cfg(test)]
49 #[must_use]
50 pub fn new_from_raw_data(num_columns: u64, table_length: u64, data: Vec<u8>) -> Self {
51 Self {
52 num_columns,
53 table_length,
54 data,
55 }
56 }
57
58 #[must_use]
64 pub fn new<'a, S: Scalar>(table_length: u64, columns: &'a [Column<'a, S>]) -> Self {
65 assert!(columns
66 .iter()
67 .all(|column| table_length == column.len() as u64));
68 let mut sz = 0;
69 for col in columns {
70 sz += col.num_bytes(table_length);
71 }
72 let mut data = vec![0u8; sz];
73 let mut sz = 0;
74 for col in columns {
75 sz += col.write(&mut data[sz..], table_length);
76 }
77 ProvableQueryResult {
78 num_columns: columns.len() as u64,
79 table_length,
80 data,
81 }
82 }
83
84 #[allow(clippy::cast_possible_truncation)]
85 #[allow(
86 clippy::missing_panics_doc,
87 reason = "Assertions ensure preconditions are met, eliminating the possibility of panic."
88 )]
89 pub fn evaluate<S: Scalar>(
96 &self,
97 evaluation_point: &[S],
98 output_length: usize,
99 column_result_fields: &[ColumnField],
100 ) -> Result<Vec<S>, QueryError> {
101 if self.num_columns as usize != column_result_fields.len() {
102 return Err(QueryError::InvalidColumnCount);
103 }
104 let mut evaluation_vec = vec![Zero::zero(); output_length];
105 compute_evaluation_vector(&mut evaluation_vec, evaluation_point);
106 let mut offset: usize = 0;
107 let mut res = Vec::with_capacity(self.num_columns as usize);
108
109 for field in column_result_fields {
110 let mut val = S::zero();
111 for entry in evaluation_vec.iter().take(output_length) {
112 let (x, sz) = match field.data_type() {
113 ColumnType::Boolean => decode_and_convert::<bool, S>(&self.data[offset..]),
114 ColumnType::Uint8 => decode_and_convert::<u8, S>(&self.data[offset..]),
115 ColumnType::TinyInt => decode_and_convert::<i8, S>(&self.data[offset..]),
116 ColumnType::SmallInt => decode_and_convert::<i16, S>(&self.data[offset..]),
117 ColumnType::Int => decode_and_convert::<i32, S>(&self.data[offset..]),
118 ColumnType::BigInt => decode_and_convert::<i64, S>(&self.data[offset..]),
119 ColumnType::Int128 => decode_and_convert::<i128, S>(&self.data[offset..]),
120 ColumnType::Decimal75(_, _) | ColumnType::Scalar => {
121 decode_and_convert::<S, S>(&self.data[offset..])
122 }
123
124 ColumnType::VarChar => decode_and_convert::<&str, S>(&self.data[offset..]),
125 ColumnType::VarBinary => {
126 let (raw_bytes, used) =
127 decode_and_convert::<&[u8], &[u8]>(&self.data[offset..])?;
128 let x = S::from_byte_slice_via_hash(raw_bytes);
129 Ok((x, used))
130 }
131 ColumnType::TimestampTZ(_, _) => {
132 decode_and_convert::<i64, S>(&self.data[offset..])
133 }
134 }?;
135 val += *entry * x;
136 offset += sz;
137 }
138 res.push(val);
139 }
140 if offset != self.data.len() {
141 return Err(QueryError::MiscellaneousEvaluationError);
142 }
143
144 Ok(res)
145 }
146
147 #[allow(
148 clippy::missing_panics_doc,
149 reason = "Assertions ensure preconditions are met, eliminating the possibility of panic."
150 )]
151 pub fn to_owned_table<S: Scalar>(
155 &self,
156 column_result_fields: &[ColumnField],
157 ) -> Result<OwnedTable<S>, QueryError> {
158 if column_result_fields.len() != self.num_columns() {
159 return Err(QueryError::InvalidColumnCount);
160 }
161
162 let n = self.table_length();
163 let mut offset: usize = 0;
164
165 let owned_table = OwnedTable::try_new(
166 column_result_fields
167 .iter()
168 .map(|field| match field.data_type() {
169 ColumnType::Boolean => {
170 let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
171 offset += num_read;
172 Ok((field.name(), OwnedColumn::Boolean(col)))
173 }
174 ColumnType::Uint8 => {
175 let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
176 offset += num_read;
177 Ok((field.name(), OwnedColumn::Uint8(col)))
178 }
179 ColumnType::TinyInt => {
180 let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
181 offset += num_read;
182 Ok((field.name(), OwnedColumn::TinyInt(col)))
183 }
184 ColumnType::SmallInt => {
185 let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
186 offset += num_read;
187 Ok((field.name(), OwnedColumn::SmallInt(col)))
188 }
189 ColumnType::Int => {
190 let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
191 offset += num_read;
192 Ok((field.name(), OwnedColumn::Int(col)))
193 }
194 ColumnType::BigInt => {
195 let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
196 offset += num_read;
197 Ok((field.name(), OwnedColumn::BigInt(col)))
198 }
199 ColumnType::Int128 => {
200 let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
201 offset += num_read;
202 Ok((field.name(), OwnedColumn::Int128(col)))
203 }
204 ColumnType::VarChar => {
205 let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
206 offset += num_read;
207 Ok((field.name(), OwnedColumn::VarChar(col)))
208 }
209 ColumnType::VarBinary => {
210 let (decoded_slices, num_read) =
212 decode_multiple_elements::<&[u8]>(&self.data[offset..], n)?;
213 offset += num_read;
214
215 let col_vec = decoded_slices.into_iter().map(<[u8]>::to_vec).collect();
217
218 Ok((field.name(), OwnedColumn::VarBinary(col_vec)))
219 }
220 ColumnType::Scalar => {
221 let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
222 offset += num_read;
223 Ok((field.name(), OwnedColumn::Scalar(col)))
224 }
225 ColumnType::Decimal75(precision, scale) => {
226 let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
227 offset += num_read;
228 Ok((field.name(), OwnedColumn::Decimal75(precision, scale, col)))
229 }
230 ColumnType::TimestampTZ(tu, tz) => {
231 let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?;
232 offset += num_read;
233 Ok((field.name(), OwnedColumn::TimestampTZ(tu, tz, col)))
234 }
235 })
236 .collect::<Result<_, QueryError>>()?,
237 )?;
238
239 assert_eq!(offset, self.data.len());
240 assert_eq!(owned_table.num_columns(), self.num_columns());
241
242 Ok(owned_table)
243 }
244}
245
246impl<S: Scalar> From<Table<'_, S>> for ProvableQueryResult {
247 fn from(table: Table<S>) -> Self {
248 let num_rows = table.num_rows();
249 let columns = table
250 .into_inner()
251 .into_iter()
252 .map(|(_, col)| col)
253 .collect::<Vec<_>>();
254 Self::new(num_rows as u64, &columns)
255 }
256}