Skip to main content

gcloud_bigquery/storage_write/
mod.rs

1use std::collections::HashMap;
2
3use arrow::error::ArrowError;
4use arrow::ipc::writer::{
5    write_message, CompressionContext, DictionaryTracker, EncodedData, IpcDataGenerator, IpcWriteOptions,
6};
7use arrow::record_batch::RecordBatch;
8use google_cloud_gax::grpc::codegen::tokio_stream::Stream;
9use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request::{ArrowData, ProtoData, Rows};
10use google_cloud_googleapis::cloud::bigquery::storage::v1::{
11    AppendRowsRequest, ArrowRecordBatch, ArrowSchema, ProtoRows, ProtoSchema,
12};
13use prost_types::DescriptorProto;
14
15mod flow;
16pub mod stream;
17
18enum Payload {
19    Proto {
20        schema: DescriptorProto,
21        rows: Vec<Vec<u8>>,
22    },
23    Arrow {
24        serialized_schema: Vec<u8>,
25        serialized_record_batch: Vec<u8>,
26    },
27}
28
29pub struct AppendRowsRequestBuilder {
30    offset: Option<i64>,
31    trace_id: Option<String>,
32    missing_value_interpretations: Option<HashMap<String, i32>>,
33    default_missing_value_interpretation: Option<i32>,
34    payload: Payload,
35}
36
37impl AppendRowsRequestBuilder {
38    pub fn new(schema: DescriptorProto, data: Vec<Vec<u8>>) -> Self {
39        Self::with_payload(Payload::Proto { schema, rows: data })
40    }
41
42    pub fn new_arrow(serialized_schema: Vec<u8>, serialized_record_batch: Vec<u8>) -> Self {
43        Self::with_payload(Payload::Arrow {
44            serialized_schema,
45            serialized_record_batch,
46        })
47    }
48
49    pub fn from_record_batch(batch: &RecordBatch) -> Result<Self, ArrowError> {
50        let options = IpcWriteOptions::default();
51        let generator = IpcDataGenerator::default();
52        let mut dict_tracker = DictionaryTracker::new(true);
53        let mut compression = CompressionContext::default();
54
55        let schema_encoded =
56            generator.schema_to_bytes_with_dictionary_tracker(&batch.schema(), &mut dict_tracker, &options);
57        let serialized_schema = encoded_to_bytes(vec![schema_encoded], &options)?;
58
59        let (dict_encoded, batch_encoded) = generator.encode(batch, &mut dict_tracker, &options, &mut compression)?;
60        let mut encoded = dict_encoded;
61        encoded.push(batch_encoded);
62        let serialized_record_batch = encoded_to_bytes(encoded, &options)?;
63
64        Ok(Self::new_arrow(serialized_schema, serialized_record_batch))
65    }
66
67    fn with_payload(payload: Payload) -> Self {
68        Self {
69            offset: None,
70            trace_id: None,
71            missing_value_interpretations: None,
72            default_missing_value_interpretation: None,
73            payload,
74        }
75    }
76
77    pub fn with_offset(mut self, offset: i64) -> Self {
78        self.offset = Some(offset);
79        self
80    }
81
82    pub fn with_trace_id(mut self, trace_id: String) -> Self {
83        self.trace_id = Some(trace_id);
84        self
85    }
86
87    pub fn with_missing_value_interpretations(mut self, missing_value_interpretations: HashMap<String, i32>) -> Self {
88        self.missing_value_interpretations = Some(missing_value_interpretations);
89        self
90    }
91
92    pub fn with_default_missing_value_interpretation(mut self, default_missing_value_interpretation: i32) -> Self {
93        self.default_missing_value_interpretation = Some(default_missing_value_interpretation);
94        self
95    }
96
97    pub(crate) fn build(self, stream: &str) -> AppendRowsRequest {
98        let rows = match self.payload {
99            Payload::Proto { schema, rows } => Rows::ProtoRows(ProtoData {
100                writer_schema: Some(ProtoSchema {
101                    proto_descriptor: Some(schema),
102                }),
103                rows: Some(ProtoRows { serialized_rows: rows }),
104            }),
105            Payload::Arrow {
106                serialized_schema,
107                serialized_record_batch,
108            } => Rows::ArrowRows(ArrowData {
109                writer_schema: Some(ArrowSchema { serialized_schema }),
110                rows: Some(ArrowRecordBatch {
111                    serialized_record_batch,
112                    #[allow(deprecated)]
113                    row_count: 0,
114                }),
115            }),
116        };
117        AppendRowsRequest {
118            write_stream: stream.to_string(),
119            offset: self.offset,
120            trace_id: self.trace_id.unwrap_or_default(),
121            missing_value_interpretations: self.missing_value_interpretations.unwrap_or_default(),
122            default_missing_value_interpretation: self.default_missing_value_interpretation.unwrap_or(0),
123            rows: Some(rows),
124        }
125    }
126}
127
128fn encoded_to_bytes(messages: Vec<EncodedData>, options: &IpcWriteOptions) -> Result<Vec<u8>, ArrowError> {
129    let mut buf = Vec::new();
130    for message in messages {
131        write_message(&mut buf, message, options)?;
132    }
133    Ok(buf)
134}
135
136pub fn into_streaming_request(rows: Vec<AppendRowsRequest>) -> impl Stream<Item = AppendRowsRequest> {
137    async_stream::stream! {
138        for row in rows {
139            yield row;
140        }
141    }
142}
143
#[cfg(test)]
mod tests {
    use std::io::{BufReader, Cursor};
    use std::sync::Arc;

    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::ipc::reader::StreamReader;
    use arrow::record_batch::RecordBatch;
    use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request::Rows;

    use super::AppendRowsRequestBuilder;

    /// Builds a three-row batch with a non-nullable Int64 `id` column and a
    /// non-nullable Utf8 `name` column.
    fn sample_batch() -> RecordBatch {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ];
        let id_column = Arc::new(Int64Array::from(vec![1, 2, 3]));
        let name_column = Arc::new(StringArray::from(vec!["a", "b", "c"]));
        RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![id_column, name_column]).unwrap()
    }

    /// A request built from a record batch must carry Arrow rows whose schema
    /// and batch bytes decode back to a batch of the same shape.
    #[test]
    fn from_record_batch_emits_arrow_rows_and_round_trips() {
        let source = sample_batch();
        let expected_rows = source.num_rows();

        let request = AppendRowsRequestBuilder::from_record_batch(&source)
            .unwrap()
            .build("projects/p/datasets/d/tables/t/streams/_default");

        let arrow_data = match request.rows.expect("rows set") {
            Rows::ArrowRows(data) => data,
            _ => panic!("expected Arrow rows variant"),
        };
        let schema_bytes = arrow_data.writer_schema.expect("writer_schema").serialized_schema;
        let batch_bytes = arrow_data.rows.expect("rows").serialized_record_batch;
        assert!(!schema_bytes.is_empty());
        assert!(!batch_bytes.is_empty());

        // Mirror storage.rs: concat schema + batch and decode with StreamReader.
        let ipc_stream = [schema_bytes, batch_bytes].concat();
        let reader = StreamReader::try_new(BufReader::new(Cursor::new(ipc_stream)), None).unwrap();
        let decoded = reader.collect::<Result<Vec<RecordBatch>, _>>().unwrap();
        assert_eq!(decoded.len(), 1);

        let round_tripped = &decoded[0];
        assert_eq!(round_tripped.num_rows(), expected_rows);
        assert_eq!(round_tripped.num_columns(), 2);
    }
}