gcloud_bigquery/storage_write/
mod.rs1use std::collections::HashMap;
2
3use arrow::error::ArrowError;
4use arrow::ipc::writer::{
5 write_message, CompressionContext, DictionaryTracker, EncodedData, IpcDataGenerator, IpcWriteOptions,
6};
7use arrow::record_batch::RecordBatch;
8use google_cloud_gax::grpc::codegen::tokio_stream::Stream;
9use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request::{ArrowData, ProtoData, Rows};
10use google_cloud_googleapis::cloud::bigquery::storage::v1::{
11 AppendRowsRequest, ArrowRecordBatch, ArrowSchema, ProtoRows, ProtoSchema,
12};
13use prost_types::DescriptorProto;
14
15mod flow;
16pub mod stream;
17
18enum Payload {
19 Proto {
20 schema: DescriptorProto,
21 rows: Vec<Vec<u8>>,
22 },
23 Arrow {
24 serialized_schema: Vec<u8>,
25 serialized_record_batch: Vec<u8>,
26 },
27}
28
29pub struct AppendRowsRequestBuilder {
30 offset: Option<i64>,
31 trace_id: Option<String>,
32 missing_value_interpretations: Option<HashMap<String, i32>>,
33 default_missing_value_interpretation: Option<i32>,
34 payload: Payload,
35}
36
37impl AppendRowsRequestBuilder {
38 pub fn new(schema: DescriptorProto, data: Vec<Vec<u8>>) -> Self {
39 Self::with_payload(Payload::Proto { schema, rows: data })
40 }
41
42 pub fn new_arrow(serialized_schema: Vec<u8>, serialized_record_batch: Vec<u8>) -> Self {
43 Self::with_payload(Payload::Arrow {
44 serialized_schema,
45 serialized_record_batch,
46 })
47 }
48
49 pub fn from_record_batch(batch: &RecordBatch) -> Result<Self, ArrowError> {
50 let options = IpcWriteOptions::default();
51 let generator = IpcDataGenerator::default();
52 let mut dict_tracker = DictionaryTracker::new(true);
53 let mut compression = CompressionContext::default();
54
55 let schema_encoded =
56 generator.schema_to_bytes_with_dictionary_tracker(&batch.schema(), &mut dict_tracker, &options);
57 let serialized_schema = encoded_to_bytes(vec![schema_encoded], &options)?;
58
59 let (dict_encoded, batch_encoded) = generator.encode(batch, &mut dict_tracker, &options, &mut compression)?;
60 let mut encoded = dict_encoded;
61 encoded.push(batch_encoded);
62 let serialized_record_batch = encoded_to_bytes(encoded, &options)?;
63
64 Ok(Self::new_arrow(serialized_schema, serialized_record_batch))
65 }
66
67 fn with_payload(payload: Payload) -> Self {
68 Self {
69 offset: None,
70 trace_id: None,
71 missing_value_interpretations: None,
72 default_missing_value_interpretation: None,
73 payload,
74 }
75 }
76
77 pub fn with_offset(mut self, offset: i64) -> Self {
78 self.offset = Some(offset);
79 self
80 }
81
82 pub fn with_trace_id(mut self, trace_id: String) -> Self {
83 self.trace_id = Some(trace_id);
84 self
85 }
86
87 pub fn with_missing_value_interpretations(mut self, missing_value_interpretations: HashMap<String, i32>) -> Self {
88 self.missing_value_interpretations = Some(missing_value_interpretations);
89 self
90 }
91
92 pub fn with_default_missing_value_interpretation(mut self, default_missing_value_interpretation: i32) -> Self {
93 self.default_missing_value_interpretation = Some(default_missing_value_interpretation);
94 self
95 }
96
97 pub(crate) fn build(self, stream: &str) -> AppendRowsRequest {
98 let rows = match self.payload {
99 Payload::Proto { schema, rows } => Rows::ProtoRows(ProtoData {
100 writer_schema: Some(ProtoSchema {
101 proto_descriptor: Some(schema),
102 }),
103 rows: Some(ProtoRows { serialized_rows: rows }),
104 }),
105 Payload::Arrow {
106 serialized_schema,
107 serialized_record_batch,
108 } => Rows::ArrowRows(ArrowData {
109 writer_schema: Some(ArrowSchema { serialized_schema }),
110 rows: Some(ArrowRecordBatch {
111 serialized_record_batch,
112 #[allow(deprecated)]
113 row_count: 0,
114 }),
115 }),
116 };
117 AppendRowsRequest {
118 write_stream: stream.to_string(),
119 offset: self.offset,
120 trace_id: self.trace_id.unwrap_or_default(),
121 missing_value_interpretations: self.missing_value_interpretations.unwrap_or_default(),
122 default_missing_value_interpretation: self.default_missing_value_interpretation.unwrap_or(0),
123 rows: Some(rows),
124 }
125 }
126}
127
128fn encoded_to_bytes(messages: Vec<EncodedData>, options: &IpcWriteOptions) -> Result<Vec<u8>, ArrowError> {
129 let mut buf = Vec::new();
130 for message in messages {
131 write_message(&mut buf, message, options)?;
132 }
133 Ok(buf)
134}
135
136pub fn into_streaming_request(rows: Vec<AppendRowsRequest>) -> impl Stream<Item = AppendRowsRequest> {
137 async_stream::stream! {
138 for row in rows {
139 yield row;
140 }
141 }
142}
143
#[cfg(test)]
mod tests {
    use std::io::{BufReader, Cursor};
    use std::sync::Arc;

    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::ipc::reader::StreamReader;
    use arrow::record_batch::RecordBatch;
    use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request::Rows;

    use super::AppendRowsRequestBuilder;

    /// Three-row batch with non-nullable `id: Int64` and `name: Utf8` columns.
    fn sample_batch() -> RecordBatch {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ];
        let id_column = Int64Array::from(vec![1, 2, 3]);
        let name_column = StringArray::from(vec!["a", "b", "c"]);
        RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![Arc::new(id_column), Arc::new(name_column)],
        )
        .unwrap()
    }

    /// The builder must emit Arrow rows whose schema and batch bytes, read back
    /// as a single IPC stream, decode to one batch of the original shape.
    #[test]
    fn from_record_batch_emits_arrow_rows_and_round_trips() {
        let source = sample_batch();
        let request = AppendRowsRequestBuilder::from_record_batch(&source)
            .unwrap()
            .build("projects/p/datasets/d/tables/t/streams/_default");

        let arrow_data = match request.rows.expect("rows set") {
            Rows::ArrowRows(data) => data,
            _ => panic!("expected Arrow rows variant"),
        };
        let schema_bytes = arrow_data.writer_schema.expect("writer_schema").serialized_schema;
        let batch_bytes = arrow_data.rows.expect("rows").serialized_record_batch;
        assert!(!schema_bytes.is_empty());
        assert!(!batch_bytes.is_empty());

        // Schema bytes followed by batch bytes are fed to the reader as one stream.
        let ipc_stream = [schema_bytes, batch_bytes].concat();
        let reader = StreamReader::try_new(BufReader::new(Cursor::new(ipc_stream)), None).unwrap();
        let decoded = reader.collect::<Result<Vec<RecordBatch>, _>>().unwrap();
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].num_rows(), source.num_rows());
        assert_eq!(decoded[0].num_columns(), 2);
    }
}