1use std::collections::VecDeque;
2use std::io::{BufReader, Cursor};
3
4use arrow::error::ArrowError;
5use arrow::ipc::reader::StreamReader;
6pub use arrow::*;
7
8use google_cloud_gax::grpc::{Status, Streaming};
9use google_cloud_gax::retry::RetrySetting;
10use google_cloud_googleapis::cloud::bigquery::storage::v1::read_rows_response::{Rows, Schema};
11use google_cloud_googleapis::cloud::bigquery::storage::v1::{
12 ArrowSchema, ReadRowsRequest, ReadRowsResponse, ReadSession,
13};
14
15use crate::grpc::apiv1::bigquery_client::StreamingReadClient;
16use crate::storage::value::StructDecodable;
17
18#[derive(thiserror::Error, Debug)]
19pub enum Error {
20 #[error(transparent)]
21 GRPC(#[from] Status),
22 #[error(transparent)]
23 ArrowNative(#[from] ArrowError),
24 #[error(transparent)]
25 Value(#[from] value::Error),
26 #[error("data format must be arrow")]
27 InvalidDataFormat,
28 #[error("schema format must be arrow")]
29 InvalidSchemaFormat,
30 #[error("no schema found in first response")]
31 NoSchemaFound,
32}
33
34pub struct Iterator<T>
35where
36 T: StructDecodable,
37{
38 client: StreamingReadClient,
39 session: ReadSession,
40 retry: Option<RetrySetting>,
41 stream_index: usize,
43 current_stream: Streaming<ReadRowsResponse>,
44 chunk: VecDeque<T>,
45 schema: Option<ArrowSchema>,
46}
47
48impl<T> Iterator<T>
49where
50 T: StructDecodable,
51{
52 pub async fn new(
53 mut client: StreamingReadClient,
54 session: ReadSession,
55 retry: Option<RetrySetting>,
56 ) -> Result<Self, Error> {
57 let current_stream = client
58 .read_rows(
59 ReadRowsRequest {
60 read_stream: session.streams[0].name.to_string(),
61 offset: 0,
62 },
63 retry.clone(),
64 )
65 .await?
66 .into_inner();
67 Ok(Self {
68 client,
69 session,
70 retry,
71 current_stream,
72 stream_index: 0,
73 chunk: VecDeque::new(),
74 schema: None,
75 })
76 }
77
78 pub async fn next(&mut self) -> Result<Option<T>, Error> {
79 loop {
80 if let Some(row) = self.chunk.pop_front() {
81 return Ok(Some(row));
82 }
83 if let Some(rows) = self.current_stream.message().await? {
84 if self.schema.is_none() {
85 match rows.schema.ok_or(Error::NoSchemaFound)? {
86 Schema::ArrowSchema(schema) => self.schema = Some(schema),
87 _ => return Err(Error::InvalidSchemaFormat),
88 }
89 };
90 if let Some(rows) = rows.rows {
91 self.chunk = rows_to_chunk(self.schema.clone().unwrap(), rows)?;
92 return Ok(self.chunk.pop_front());
93 }
94 }
95
96 if self.stream_index == self.session.streams.len() - 1 {
97 return Ok(None);
98 } else {
99 self.stream_index += 1
100 }
101 let stream = &self.session.streams[self.stream_index].name;
102 self.current_stream = self
103 .client
104 .read_rows(
105 ReadRowsRequest {
106 read_stream: stream.to_string(),
107 offset: 0,
108 },
109 self.retry.clone(),
110 )
111 .await?
112 .into_inner();
113 }
114 }
115}
116
117fn rows_to_chunk<T>(schema: ArrowSchema, rows: Rows) -> Result<VecDeque<T>, Error>
118where
119 T: StructDecodable,
120{
121 match rows {
122 Rows::ArrowRecordBatch(rows) => {
123 let mut rows_with_schema = schema.serialized_schema;
124 rows_with_schema.extend_from_slice(&rows.serialized_record_batch);
125 let rows = Cursor::new(rows_with_schema);
126 let rows: StreamReader<BufReader<Cursor<Vec<u8>>>> = StreamReader::try_new(BufReader::new(rows), None)?;
127 let mut chunk: VecDeque<T> = VecDeque::new();
128 for row in rows {
129 let row = row?;
130 for row_no in 0..row.num_rows() {
131 chunk.push_back(T::decode_arrow(row.columns(), row_no)?)
132 }
133 }
134 Ok(chunk)
135 }
136 _ => Err(Error::InvalidDataFormat),
137 }
138}
139
140pub mod row {
141 use arrow::array::ArrayRef;
142
143 use crate::storage::value::{Decodable, StructDecodable};
144
145 #[derive(thiserror::Error, Debug)]
146 pub enum Error {
147 #[error("UnexpectedColumnIndex: {0}")]
148 UnexpectedColumnIndex(usize),
149 #[error(transparent)]
150 ArrowError(#[from] super::value::Error),
151 }
152
153 pub struct Row {
154 fields: Vec<ArrayRef>,
155 row_no: usize,
156 }
157
158 impl StructDecodable for Row {
159 fn decode_arrow(fields: &[ArrayRef], row_no: usize) -> Result<Row, super::value::Error> {
160 Ok(Self {
161 fields: fields.to_vec(),
162 row_no,
163 })
164 }
165 }
166
167 impl Row {
168 pub fn column<T: Decodable>(&self, index: usize) -> Result<T, Error> {
169 let column = self.fields.get(index).ok_or(Error::UnexpectedColumnIndex(index))?;
170 Ok(T::decode_arrow(column, self.row_no)?)
171 }
172 }
173}
174
175pub mod value {
176 use std::ops::Add;
177
178 use arrow::array::{
179 Array, ArrayRef, AsArray, BinaryArray, Date32Array, Decimal128Array, Decimal256Array, Float64Array, Int64Array,
180 ListArray, StringArray, Time64MicrosecondArray, TimestampMicrosecondArray,
181 };
182 use arrow::datatypes::{DataType, TimeUnit};
183 use bigdecimal::BigDecimal;
184 use time::macros::date;
185 use time::{Date, Duration, OffsetDateTime, Time};
186
187 #[derive(thiserror::Error, Debug)]
188 pub enum Error {
189 #[error("invalid data type actual={0}, expected={1}")]
190 InvalidDataType(DataType, &'static str),
191 #[error("invalid downcast dataType={0}")]
192 InvalidDowncast(DataType),
193 #[error("invalid non nullable")]
194 InvalidNullable,
195 #[error(transparent)]
196 InvalidTime(#[from] time::error::ComponentRange),
197 #[error(transparent)]
198 InvalidDecimal(#[from] bigdecimal::ParseBigDecimalError),
199 }
200
201 pub trait Decodable: Sized {
203 fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error>;
204 }
205
206 pub trait StructDecodable: Sized {
207 fn decode_arrow(fields: &[ArrayRef], row_no: usize) -> Result<Self, Error>;
208 }
209
210 impl<S> Decodable for S
211 where
212 S: StructDecodable,
213 {
214 fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<S, Error> {
215 match col.data_type() {
216 DataType::Struct(_) => S::decode_arrow(downcast::<arrow::array::StructArray>(col)?.columns(), row_no),
217 _ => Err(Error::InvalidDataType(col.data_type().clone(), "struct")),
218 }
219 }
220 }
221
222 impl Decodable for bool {
223 fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
224 if col.is_null(row_no) {
225 return Err(Error::InvalidNullable);
226 }
227 match col.data_type() {
228 DataType::Boolean => Ok(col.as_boolean().value(row_no)),
229 _ => Err(Error::InvalidDataType(col.data_type().clone(), "bool")),
230 }
231 }
232 }
233
234 impl Decodable for i64 {
235 fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
236 if col.is_null(row_no) {
237 return Err(Error::InvalidNullable);
238 }
239 match col.data_type() {
240 DataType::Int64 => Ok(downcast::<Int64Array>(col)?.value(row_no)),
241 _ => Err(Error::InvalidDataType(col.data_type().clone(), "i64")),
242 }
243 }
244 }
245
246 impl Decodable for f64 {
247 fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
248 if col.is_null(row_no) {
249 return Err(Error::InvalidNullable);
250 }
251 match col.data_type() {
252 DataType::Float64 => Ok(downcast::<Float64Array>(col)?.value(row_no)),
253 _ => Err(Error::InvalidDataType(col.data_type().clone(), "f64")),
254 }
255 }
256 }
257
258 impl Decodable for Vec<u8> {
259 fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
260 if col.is_null(row_no) {
261 return Err(Error::InvalidNullable);
262 }
263 match col.data_type() {
264 DataType::Binary => Ok(downcast::<BinaryArray>(col)?.value(row_no).into()),
265 _ => Err(Error::InvalidDataType(col.data_type().clone(), "Vec<u8>")),
266 }
267 }
268 }
269
270 impl Decodable for String {
271 fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
272 if col.is_null(row_no) {
273 return Err(Error::InvalidNullable);
274 }
275 match col.data_type() {
276 DataType::Decimal128(_, _) => BigDecimal::decode_arrow(col, row_no).map(|v| v.to_string()),
277 DataType::Decimal256(_, _) => BigDecimal::decode_arrow(col, row_no).map(|v| v.to_string()),
278 DataType::Date32 => Date::decode_arrow(col, row_no).map(|v| v.to_string()),
279 DataType::Timestamp(_, _) => OffsetDateTime::decode_arrow(col, row_no).map(|v| v.to_string()),
280 DataType::Time64(_) => Time::decode_arrow(col, row_no).map(|v| v.to_string()),
281 DataType::Boolean => bool::decode_arrow(col, row_no).map(|v| v.to_string()),
282 DataType::Float64 => f64::decode_arrow(col, row_no).map(|v| v.to_string()),
283 DataType::Int64 => i64::decode_arrow(col, row_no).map(|v| v.to_string()),
284 DataType::Utf8 => Ok(downcast::<StringArray>(col)?.value(row_no).to_string()),
285 _ => Err(Error::InvalidDataType(col.data_type().clone(), "String")),
286 }
287 }
288 }
289
290 impl Decodable for BigDecimal {
291 fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
292 if col.is_null(row_no) {
293 return Err(Error::InvalidNullable);
294 }
295 match col.data_type() {
296 DataType::Decimal128(_, _) => {
297 let decimal = downcast::<Decimal128Array>(col)?;
298 let value = decimal.value(row_no);
299 let bigint = num_bigint::BigInt::from_signed_bytes_le(&value.to_le_bytes());
300 Ok(BigDecimal::from((bigint, decimal.scale() as i64)))
301 }
302 DataType::Decimal256(_, _) => {
303 let decimal = downcast::<Decimal256Array>(col)?;
304 let value = decimal.value(row_no);
305 let bigint = num_bigint::BigInt::from_signed_bytes_le(&value.to_le_bytes());
306 Ok(BigDecimal::from((bigint, decimal.scale() as i64)))
307 }
308 _ => Err(Error::InvalidDataType(col.data_type().clone(), "Decimal128")),
309 }
310 }
311 }
312
313 impl Decodable for Time {
314 fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
315 if col.is_null(row_no) {
316 return Err(Error::InvalidNullable);
317 }
318 match col.data_type() {
319 DataType::Time64(TimeUnit::Microsecond) => {
320 let micros = downcast::<Time64MicrosecondArray>(col)?.value(row_no);
321 let hour = micros / 3600_000000;
324 let rest_micros = micros % 3600_000000;
325 let minute = rest_micros / 60_000000;
326 let rest_micros = rest_micros % 60_000000;
327 let secs = rest_micros / 1_000_000;
328 let rest_micros = rest_micros % 1_000_000;
329 Ok(Time::from_hms_micro(hour as u8, minute as u8, secs as u8, rest_micros as u32)?)
330 }
331 _ => Err(Error::InvalidDataType(col.data_type().clone(), "Time")),
332 }
333 }
334 }
335
336 impl Decodable for Date {
337 fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
338 if col.is_null(row_no) {
339 return Err(Error::InvalidNullable);
340 }
341 match col.data_type() {
342 DataType::Date32 => {
343 let days_from_epoch = downcast::<Date32Array>(col)?.value(row_no);
344 const UNIX_EPOCH: Date = date!(1970 - 01 - 01);
345 Ok(UNIX_EPOCH.add(Duration::days(days_from_epoch as i64)))
346 }
347 _ => Err(Error::InvalidDataType(col.data_type().clone(), "DaysFromEpoch")),
348 }
349 }
350 }
351
352 impl Decodable for OffsetDateTime {
353 fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
354 if col.is_null(row_no) {
355 return Err(Error::InvalidNullable);
356 }
357 match col.data_type() {
358 DataType::Timestamp(TimeUnit::Microsecond, _) => {
359 let micros = downcast::<TimestampMicrosecondArray>(col)?.value(row_no);
360 Ok(OffsetDateTime::from_unix_timestamp_nanos(micros as i128 * 1000)?)
361 }
362 _ => Err(Error::InvalidDataType(col.data_type().clone(), "Days")),
363 }
364 }
365 }
366
367 impl<T> Decodable for Option<T>
368 where
369 T: Decodable,
370 {
371 fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Option<T>, Error> {
372 if col.is_null(row_no) {
373 return Ok(None);
374 }
375 Ok(Some(T::decode_arrow(col, row_no)?))
376 }
377 }
378
379 impl<T> Decodable for Vec<T>
380 where
381 T: Decodable,
382 {
383 fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Vec<T>, Error> {
384 match col.data_type() {
385 DataType::List(_) => {
386 let list = downcast::<ListArray>(col)?;
387 let col = list.value(row_no);
388 let mut result: Vec<T> = Vec::with_capacity(col.len());
389 for row_num in 0..col.len() {
390 result.push(T::decode_arrow(&col, row_num)?);
391 }
392 Ok(result)
393 }
394 _ => Err(Error::InvalidDataType(col.data_type().clone(), "Days")),
395 }
396 }
397 }
398
399 fn downcast<T: 'static>(col: &dyn Array) -> Result<&T, Error> {
400 col.as_any()
401 .downcast_ref::<T>()
402 .ok_or(Error::InvalidDowncast(col.data_type().clone()))
403 }
404}