google_cloud_bigquery/
storage.rs

1use std::collections::VecDeque;
2use std::io::{BufReader, Cursor};
3
4use arrow::error::ArrowError;
5use arrow::ipc::reader::StreamReader;
6pub use arrow::*;
7
8use google_cloud_gax::grpc::{Status, Streaming};
9use google_cloud_gax::retry::RetrySetting;
10use google_cloud_googleapis::cloud::bigquery::storage::v1::read_rows_response::{Rows, Schema};
11use google_cloud_googleapis::cloud::bigquery::storage::v1::{
12    ArrowSchema, ReadRowsRequest, ReadRowsResponse, ReadSession,
13};
14
15use crate::grpc::apiv1::bigquery_client::StreamingReadClient;
16use crate::storage::value::StructDecodable;
17
18#[derive(thiserror::Error, Debug)]
19pub enum Error {
20    #[error(transparent)]
21    GRPC(#[from] Status),
22    #[error(transparent)]
23    ArrowNative(#[from] ArrowError),
24    #[error(transparent)]
25    Value(#[from] value::Error),
26    #[error("data format must be arrow")]
27    InvalidDataFormat,
28    #[error("schema format must be arrow")]
29    InvalidSchemaFormat,
30    #[error("no schema found in first response")]
31    NoSchemaFound,
32}
33
34pub struct Iterator<T>
35where
36    T: StructDecodable,
37{
38    client: StreamingReadClient,
39    session: ReadSession,
40    retry: Option<RetrySetting>,
41    // mutable
42    stream_index: usize,
43    current_stream: Streaming<ReadRowsResponse>,
44    chunk: VecDeque<T>,
45    schema: Option<ArrowSchema>,
46}
47
48impl<T> Iterator<T>
49where
50    T: StructDecodable,
51{
52    pub async fn new(
53        mut client: StreamingReadClient,
54        session: ReadSession,
55        retry: Option<RetrySetting>,
56    ) -> Result<Self, Error> {
57        let current_stream = client
58            .read_rows(
59                ReadRowsRequest {
60                    read_stream: session.streams[0].name.to_string(),
61                    offset: 0,
62                },
63                retry.clone(),
64            )
65            .await?
66            .into_inner();
67        Ok(Self {
68            client,
69            session,
70            retry,
71            current_stream,
72            stream_index: 0,
73            chunk: VecDeque::new(),
74            schema: None,
75        })
76    }
77
78    pub async fn next(&mut self) -> Result<Option<T>, Error> {
79        loop {
80            if let Some(row) = self.chunk.pop_front() {
81                return Ok(Some(row));
82            }
83            if let Some(rows) = self.current_stream.message().await? {
84                if self.schema.is_none() {
85                    match rows.schema.ok_or(Error::NoSchemaFound)? {
86                        Schema::ArrowSchema(schema) => self.schema = Some(schema),
87                        _ => return Err(Error::InvalidSchemaFormat),
88                    }
89                };
90                if let Some(rows) = rows.rows {
91                    self.chunk = rows_to_chunk(self.schema.clone().unwrap(), rows)?;
92                    return Ok(self.chunk.pop_front());
93                }
94            }
95
96            if self.stream_index == self.session.streams.len() - 1 {
97                return Ok(None);
98            } else {
99                self.stream_index += 1
100            }
101            let stream = &self.session.streams[self.stream_index].name;
102            self.current_stream = self
103                .client
104                .read_rows(
105                    ReadRowsRequest {
106                        read_stream: stream.to_string(),
107                        offset: 0,
108                    },
109                    self.retry.clone(),
110                )
111                .await?
112                .into_inner();
113        }
114    }
115}
116
117fn rows_to_chunk<T>(schema: ArrowSchema, rows: Rows) -> Result<VecDeque<T>, Error>
118where
119    T: StructDecodable,
120{
121    match rows {
122        Rows::ArrowRecordBatch(rows) => {
123            let mut rows_with_schema = schema.serialized_schema;
124            rows_with_schema.extend_from_slice(&rows.serialized_record_batch);
125            let rows = Cursor::new(rows_with_schema);
126            let rows: StreamReader<BufReader<Cursor<Vec<u8>>>> = StreamReader::try_new(BufReader::new(rows), None)?;
127            let mut chunk: VecDeque<T> = VecDeque::new();
128            for row in rows {
129                let row = row?;
130                for row_no in 0..row.num_rows() {
131                    chunk.push_back(T::decode_arrow(row.columns(), row_no)?)
132                }
133            }
134            Ok(chunk)
135        }
136        _ => Err(Error::InvalidDataFormat),
137    }
138}
139
140pub mod row {
141    use arrow::array::ArrayRef;
142
143    use crate::storage::value::{Decodable, StructDecodable};
144
145    #[derive(thiserror::Error, Debug)]
146    pub enum Error {
147        #[error("UnexpectedColumnIndex: {0}")]
148        UnexpectedColumnIndex(usize),
149        #[error(transparent)]
150        ArrowError(#[from] super::value::Error),
151    }
152
153    pub struct Row {
154        fields: Vec<ArrayRef>,
155        row_no: usize,
156    }
157
158    impl StructDecodable for Row {
159        fn decode_arrow(fields: &[ArrayRef], row_no: usize) -> Result<Row, super::value::Error> {
160            Ok(Self {
161                fields: fields.to_vec(),
162                row_no,
163            })
164        }
165    }
166
167    impl Row {
168        pub fn column<T: Decodable>(&self, index: usize) -> Result<T, Error> {
169            let column = self.fields.get(index).ok_or(Error::UnexpectedColumnIndex(index))?;
170            Ok(T::decode_arrow(column, self.row_no)?)
171        }
172    }
173}
174
175pub mod value {
176    use std::ops::Add;
177
178    use arrow::array::{
179        Array, ArrayRef, AsArray, BinaryArray, Date32Array, Decimal128Array, Decimal256Array, Float64Array, Int64Array,
180        ListArray, StringArray, Time64MicrosecondArray, TimestampMicrosecondArray,
181    };
182    use arrow::datatypes::{DataType, TimeUnit};
183    use bigdecimal::BigDecimal;
184    use time::macros::date;
185    use time::{Date, Duration, OffsetDateTime, Time};
186
187    #[derive(thiserror::Error, Debug)]
188    pub enum Error {
189        #[error("invalid data type actual={0}, expected={1}")]
190        InvalidDataType(DataType, &'static str),
191        #[error("invalid downcast dataType={0}")]
192        InvalidDowncast(DataType),
193        #[error("invalid non nullable")]
194        InvalidNullable,
195        #[error(transparent)]
196        InvalidTime(#[from] time::error::ComponentRange),
197        #[error(transparent)]
198        InvalidDecimal(#[from] bigdecimal::ParseBigDecimalError),
199    }
200
201    /// https://cloud.google.com/bigquery/docs/reference/storage#arrow_schema_details
202    pub trait Decodable: Sized {
203        fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error>;
204    }
205
206    pub trait StructDecodable: Sized {
207        fn decode_arrow(fields: &[ArrayRef], row_no: usize) -> Result<Self, Error>;
208    }
209
210    impl<S> Decodable for S
211    where
212        S: StructDecodable,
213    {
214        fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<S, Error> {
215            match col.data_type() {
216                DataType::Struct(_) => S::decode_arrow(downcast::<arrow::array::StructArray>(col)?.columns(), row_no),
217                _ => Err(Error::InvalidDataType(col.data_type().clone(), "struct")),
218            }
219        }
220    }
221
222    impl Decodable for bool {
223        fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
224            if col.is_null(row_no) {
225                return Err(Error::InvalidNullable);
226            }
227            match col.data_type() {
228                DataType::Boolean => Ok(col.as_boolean().value(row_no)),
229                _ => Err(Error::InvalidDataType(col.data_type().clone(), "bool")),
230            }
231        }
232    }
233
234    impl Decodable for i64 {
235        fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
236            if col.is_null(row_no) {
237                return Err(Error::InvalidNullable);
238            }
239            match col.data_type() {
240                DataType::Int64 => Ok(downcast::<Int64Array>(col)?.value(row_no)),
241                _ => Err(Error::InvalidDataType(col.data_type().clone(), "i64")),
242            }
243        }
244    }
245
246    impl Decodable for f64 {
247        fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
248            if col.is_null(row_no) {
249                return Err(Error::InvalidNullable);
250            }
251            match col.data_type() {
252                DataType::Float64 => Ok(downcast::<Float64Array>(col)?.value(row_no)),
253                _ => Err(Error::InvalidDataType(col.data_type().clone(), "f64")),
254            }
255        }
256    }
257
258    impl Decodable for Vec<u8> {
259        fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
260            if col.is_null(row_no) {
261                return Err(Error::InvalidNullable);
262            }
263            match col.data_type() {
264                DataType::Binary => Ok(downcast::<BinaryArray>(col)?.value(row_no).into()),
265                _ => Err(Error::InvalidDataType(col.data_type().clone(), "Vec<u8>")),
266            }
267        }
268    }
269
270    impl Decodable for String {
271        fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
272            if col.is_null(row_no) {
273                return Err(Error::InvalidNullable);
274            }
275            match col.data_type() {
276                DataType::Decimal128(_, _) => BigDecimal::decode_arrow(col, row_no).map(|v| v.to_string()),
277                DataType::Decimal256(_, _) => BigDecimal::decode_arrow(col, row_no).map(|v| v.to_string()),
278                DataType::Date32 => Date::decode_arrow(col, row_no).map(|v| v.to_string()),
279                DataType::Timestamp(_, _) => OffsetDateTime::decode_arrow(col, row_no).map(|v| v.to_string()),
280                DataType::Time64(_) => Time::decode_arrow(col, row_no).map(|v| v.to_string()),
281                DataType::Boolean => bool::decode_arrow(col, row_no).map(|v| v.to_string()),
282                DataType::Float64 => f64::decode_arrow(col, row_no).map(|v| v.to_string()),
283                DataType::Int64 => i64::decode_arrow(col, row_no).map(|v| v.to_string()),
284                DataType::Utf8 => Ok(downcast::<StringArray>(col)?.value(row_no).to_string()),
285                _ => Err(Error::InvalidDataType(col.data_type().clone(), "String")),
286            }
287        }
288    }
289
290    impl Decodable for BigDecimal {
291        fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
292            if col.is_null(row_no) {
293                return Err(Error::InvalidNullable);
294            }
295            match col.data_type() {
296                DataType::Decimal128(_, _) => {
297                    let decimal = downcast::<Decimal128Array>(col)?;
298                    let value = decimal.value(row_no);
299                    let bigint = num_bigint::BigInt::from_signed_bytes_le(&value.to_le_bytes());
300                    Ok(BigDecimal::from((bigint, decimal.scale() as i64)))
301                }
302                DataType::Decimal256(_, _) => {
303                    let decimal = downcast::<Decimal256Array>(col)?;
304                    let value = decimal.value(row_no);
305                    let bigint = num_bigint::BigInt::from_signed_bytes_le(&value.to_le_bytes());
306                    Ok(BigDecimal::from((bigint, decimal.scale() as i64)))
307                }
308                _ => Err(Error::InvalidDataType(col.data_type().clone(), "Decimal128")),
309            }
310        }
311    }
312
313    impl Decodable for Time {
314        fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
315            if col.is_null(row_no) {
316                return Err(Error::InvalidNullable);
317            }
318            match col.data_type() {
319                DataType::Time64(TimeUnit::Microsecond) => {
320                    let micros = downcast::<Time64MicrosecondArray>(col)?.value(row_no);
321                    // ex) TIME(15, 30, 01) is 55801000000
322                    // ex) TIME_ADD(TIME(15, 30, 01), INTERVAL 10 MICROSECOND) is 55801000010
323                    let hour = micros / 3600_000000;
324                    let rest_micros = micros % 3600_000000;
325                    let minute = rest_micros / 60_000000;
326                    let rest_micros = rest_micros % 60_000000;
327                    let secs = rest_micros / 1_000_000;
328                    let rest_micros = rest_micros % 1_000_000;
329                    Ok(Time::from_hms_micro(hour as u8, minute as u8, secs as u8, rest_micros as u32)?)
330                }
331                _ => Err(Error::InvalidDataType(col.data_type().clone(), "Time")),
332            }
333        }
334    }
335
336    impl Decodable for Date {
337        fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
338            if col.is_null(row_no) {
339                return Err(Error::InvalidNullable);
340            }
341            match col.data_type() {
342                DataType::Date32 => {
343                    let days_from_epoch = downcast::<Date32Array>(col)?.value(row_no);
344                    const UNIX_EPOCH: Date = date!(1970 - 01 - 01);
345                    Ok(UNIX_EPOCH.add(Duration::days(days_from_epoch as i64)))
346                }
347                _ => Err(Error::InvalidDataType(col.data_type().clone(), "DaysFromEpoch")),
348            }
349        }
350    }
351
352    impl Decodable for OffsetDateTime {
353        fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Self, Error> {
354            if col.is_null(row_no) {
355                return Err(Error::InvalidNullable);
356            }
357            match col.data_type() {
358                DataType::Timestamp(TimeUnit::Microsecond, _) => {
359                    let micros = downcast::<TimestampMicrosecondArray>(col)?.value(row_no);
360                    Ok(OffsetDateTime::from_unix_timestamp_nanos(micros as i128 * 1000)?)
361                }
362                _ => Err(Error::InvalidDataType(col.data_type().clone(), "Days")),
363            }
364        }
365    }
366
367    impl<T> Decodable for Option<T>
368    where
369        T: Decodable,
370    {
371        fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Option<T>, Error> {
372            if col.is_null(row_no) {
373                return Ok(None);
374            }
375            Ok(Some(T::decode_arrow(col, row_no)?))
376        }
377    }
378
379    impl<T> Decodable for Vec<T>
380    where
381        T: Decodable,
382    {
383        fn decode_arrow(col: &dyn Array, row_no: usize) -> Result<Vec<T>, Error> {
384            match col.data_type() {
385                DataType::List(_) => {
386                    let list = downcast::<ListArray>(col)?;
387                    let col = list.value(row_no);
388                    let mut result: Vec<T> = Vec::with_capacity(col.len());
389                    for row_num in 0..col.len() {
390                        result.push(T::decode_arrow(&col, row_num)?);
391                    }
392                    Ok(result)
393                }
394                _ => Err(Error::InvalidDataType(col.data_type().clone(), "Days")),
395            }
396        }
397    }
398
399    fn downcast<T: 'static>(col: &dyn Array) -> Result<&T, Error> {
400        col.as_any()
401            .downcast_ref::<T>()
402            .ok_or(Error::InvalidDowncast(col.data_type().clone()))
403    }
404}