1use tonic::Streaming;
2
3use futures::future::ready;
4use futures::stream::{StreamExt, TryStreamExt};
5
6use std::io::Cursor;
7
8use crate::googleapis::{
9 read_rows_response::Rows, read_session::Schema, ArrowRecordBatch, ArrowSchema, ReadRowsResponse,
10};
11use crate::Error;
12
13#[cfg(feature = "arrow")]
14use arrow::ipc::reader::StreamReader as ArrowStreamReader;
15
/// Strips the leading `0xFFFFFFFF` marker from a serialized Arrow IPC message.
///
/// In the Arrow IPC encapsulated-message format every message starts with this 4-byte
/// continuation marker, followed by the int32 metadata length. The returned slice starts at
/// that length prefix, i.e. it is the message in the marker-less layout.
///
/// # Errors
///
/// Returns an error if `msg` is shorter than four bytes, or if its first four bytes are not
/// the `0xFFFFFFFF` marker.
#[cfg(feature = "arrow")]
fn strip_continuation_bytes(msg: &[u8]) -> Result<&[u8], Error> {
    const CONTINUATION_MARKER: [u8; 4] = [0xFF; 4];

    if msg.len() < CONTINUATION_MARKER.len() {
        return Err(Error::invalid("arrow message of invalid len"));
    }
    // After the length check the remainder always exists (it may be empty), so the only
    // possible failure left is a mismatched marker.
    msg.strip_prefix(&CONTINUATION_MARKER[..])
        .ok_or_else(|| Error::invalid("invalid arrow message"))
}
29
/// Arrow IPC stream reader type produced by `RowsStreamReader::into_arrow_reader`; it reads
/// from an owned in-memory byte buffer wrapped in a `Cursor`.
#[cfg(feature = "arrow")]
pub type DefaultArrowStreamReader = ArrowStreamReader<Cursor<Vec<u8>>>;
32
/// Pairs a read session's [`Schema`] with the server stream of [`ReadRowsResponse`]
/// messages so that the rows can later be decoded (see `into_arrow_reader`).
pub struct RowsStreamReader {
    /// Schema of the read session; `into_arrow_reader` requires it to be an Arrow schema.
    schema: Schema,
    /// Response stream whose `rows` carry the serialized record batches.
    upstream: Streaming<ReadRowsResponse>,
}
38
39impl RowsStreamReader {
40 pub(crate) fn new(schema: Schema, upstream: Streaming<ReadRowsResponse>) -> Self {
41 Self { schema, upstream }
42 }
43
44 #[cfg(feature = "arrow")]
46 pub async fn into_arrow_reader(self) -> Result<DefaultArrowStreamReader, Error> {
47 let mut serialized_arrow_stream = self
48 .upstream
49 .map_err(|e| e.into())
50 .and_then(|resp| {
51 let ReadRowsResponse { rows, .. } = resp;
52 let out =
53 rows.ok_or(Error::invalid("no rows received"))
54 .and_then(|rows| match rows {
55 Rows::ArrowRecordBatch(ArrowRecordBatch {
56 serialized_record_batch,
57 ..
58 }) => Ok(serialized_record_batch),
59 _ => {
60 let err = Error::invalid("expected arrow record batch");
61 Err(err)
62 }
63 });
64 ready(out)
65 })
66 .boxed();
67
68 let serialized_schema = match self.schema {
69 Schema::ArrowSchema(ArrowSchema { serialized_schema }) => serialized_schema,
70 _ => return Err(Error::invalid("expected arrow schema")),
71 };
72
73 let mut buf = Vec::new();
74 buf.extend(strip_continuation_bytes(serialized_schema.as_slice())?);
75
76 while let Some(msg) = serialized_arrow_stream.next().await {
77 let msg = msg?;
78 let body = strip_continuation_bytes(msg.as_slice())?;
79 buf.extend(body);
80 }
81
82 buf.extend(&[0u8; 4]);
85
86 let reader = ArrowStreamReader::try_new(Cursor::new(buf))?;
87
88 Ok(reader)
89 }
90}