1use std::{collections::HashMap, convert::TryInto, fmt::Display, sync::Arc};
3
4use prost::Message;
5use prost_types::{
6 field_descriptor_proto::{Label, Type},
7 DescriptorProto, FieldDescriptorProto,
8};
9use tonic::{
10 transport::{Channel, ClientTlsConfig},
11 Request, Streaming,
12};
13
14use crate::google::cloud::bigquery::storage::v1::{GetWriteStreamRequest, WriteStream, WriteStreamView};
15use crate::{
16 auth::Authenticator,
17 error::BQError,
18 google::cloud::bigquery::storage::v1::{
19 append_rows_request::{self, MissingValueInterpretation, ProtoData},
20 big_query_write_client::BigQueryWriteClient,
21 AppendRowsRequest, AppendRowsResponse, ProtoSchema,
22 },
23 BIG_QUERY_V2_URL,
24};
25
/// gRPC endpoint of the BigQuery Storage API.
static BIG_QUERY_STORAGE_API_URL: &str = "https://bigquerystorage.googleapis.com";
/// Server name used for TLS domain-name verification when connecting to the endpoint above.
static BIGQUERY_STORAGE_API_DOMAIN: &str = "bigquerystorage.googleapis.com";
29
/// Protobuf scalar type of a column in the row message sent to the Storage
/// Write API.
///
/// Variants mirror the protobuf `FieldDescriptorProto` scalar types; the
/// `From<ColumnType> for Type` impl defines the exact mapping used when
/// building the writer schema.
#[derive(Clone, Copy)]
pub enum ColumnType {
    Double,
    Float,
    Int64,
    Uint64,
    Int32,
    Fixed64,
    Fixed32,
    Bool,
    String,
    Bytes,
    Uint32,
    Sfixed32,
    Sfixed64,
    Sint32,
    Sint64,
}
49
50impl From<ColumnType> for Type {
51 fn from(value: ColumnType) -> Self {
52 match value {
53 ColumnType::Double => Type::Double,
54 ColumnType::Float => Type::Float,
55 ColumnType::Int64 => Type::Int64,
56 ColumnType::Uint64 => Type::Uint64,
57 ColumnType::Int32 => Type::Int32,
58 ColumnType::Fixed64 => Type::Fixed64,
59 ColumnType::Fixed32 => Type::Fixed32,
60 ColumnType::Bool => Type::Bool,
61 ColumnType::String => Type::String,
62 ColumnType::Bytes => Type::Bytes,
63 ColumnType::Uint32 => Type::Uint32,
64 ColumnType::Sfixed32 => Type::Sfixed32,
65 ColumnType::Sfixed64 => Type::Sfixed64,
66 ColumnType::Sint32 => Type::Sint32,
67 ColumnType::Sint64 => Type::Sfixed64,
68 }
69 }
70}
71
/// Cardinality of a column, mapped onto protobuf field labels by the
/// `From<ColumnMode> for Label` impl.
#[derive(Clone, Copy)]
pub enum ColumnMode {
    /// Maps to `Label::Optional`.
    Nullable,
    /// Maps to `Label::Required`.
    Required,
    /// Maps to `Label::Repeated`.
    Repeated,
}
79
80impl From<ColumnMode> for Label {
81 fn from(value: ColumnMode) -> Self {
82 match value {
83 ColumnMode::Nullable => Label::Optional,
84 ColumnMode::Required => Label::Required,
85 ColumnMode::Repeated => Label::Repeated,
86 }
87 }
88}
89
/// Describes one field (column) of the protobuf row message; a set of these
/// is turned into a `FieldDescriptorProto` schema by `StorageApi::create_rows`.
pub struct FieldDescriptor {
    /// Protobuf field number (tag). Presumably this must match the prost tag
    /// used when encoding row messages — TODO confirm against callers.
    pub number: u32,

    /// Column name as it appears in the BigQuery table schema.
    pub name: String,

    /// Protobuf scalar type of the field.
    pub typ: ColumnType,

    /// Cardinality (nullable / required / repeated) of the field.
    pub mode: ColumnMode,
}
104
/// Schema description of the destination table: the full list of fields that
/// make up one row message.
pub struct TableDescriptor {
    /// One descriptor per column of the row message.
    pub field_descriptors: Vec<FieldDescriptor>,
}
110
/// Fully-qualified name of a BigQuery write stream. Rendered by its `Display`
/// impl as `projects/{project}/datasets/{dataset}/tables/{table}/streams/{stream}`.
pub struct StreamName {
    // GCP project id.
    project: String,

    // Dataset id within the project.
    dataset: String,

    // Table id within the dataset.
    table: String,

    // Stream id; "_default" for the table's default stream.
    stream: String,
}
125
126impl StreamName {
127 pub fn new(project: String, dataset: String, table: String, stream: String) -> StreamName {
128 StreamName {
129 project,
130 dataset,
131 table,
132 stream,
133 }
134 }
135
136 pub fn new_default(project: String, dataset: String, table: String) -> StreamName {
137 StreamName {
138 project,
139 dataset,
140 table,
141 stream: "_default".to_string(),
142 }
143 }
144}
145
146impl Display for StreamName {
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 let StreamName {
149 project,
150 dataset,
151 table,
152 stream,
153 } = self;
154 f.write_fmt(format_args!(
155 "projects/{project}/datasets/{dataset}/tables/{table}/streams/{stream}"
156 ))
157 }
158}
159
/// Client for the BigQuery Storage Write API.
#[derive(Clone)]
pub struct StorageApi {
    // gRPC client over a TLS channel to the storage endpoint.
    write_client: BigQueryWriteClient<Channel>,
    // Supplies OAuth2 access tokens attached to every request.
    auth: Arc<dyn Authenticator>,
    // Defaults to BIG_QUERY_V2_URL; can be overridden via `with_base_url`.
    base_url: String,
}
167
impl StorageApi {
    /// Wraps an existing gRPC write client together with an authenticator,
    /// using the default BigQuery v2 base URL.
    pub(crate) fn new(write_client: BigQueryWriteClient<Channel>, auth: Arc<dyn Authenticator>) -> Self {
        Self {
            write_client,
            auth,
            base_url: BIG_QUERY_V2_URL.to_string(),
        }
    }

    /// Opens a TLS channel to the BigQuery Storage API endpoint and returns a
    /// gRPC write client over it.
    pub(crate) async fn new_write_client() -> Result<BigQueryWriteClient<Channel>, BQError> {
        // Native system roots for certificate validation; server name pinned
        // to the storage API domain.
        let tls_config = ClientTlsConfig::new()
            .domain_name(BIGQUERY_STORAGE_API_DOMAIN)
            .with_native_roots();
        let channel = Channel::from_static(BIG_QUERY_STORAGE_API_URL)
            .tls_config(tls_config)?
            .connect()
            .await?;
        let write_client = BigQueryWriteClient::new(channel);

        Ok(write_client)
    }

    /// Overrides the base URL (e.g. to point at a test server or emulator).
    pub(crate) fn with_base_url(&mut self, base_url: String) -> &mut Self {
        self.base_url = base_url;
        self
    }

    /// Appends `rows` to the stream identified by `stream_name` and returns
    /// the server's streaming response.
    ///
    /// The request is sent with no explicit offset and with the
    /// missing-value interpretation left unspecified. `trace_id` is forwarded
    /// to the service for tracing/debugging.
    pub async fn append_rows(
        &mut self,
        stream_name: &StreamName,
        rows: append_rows_request::Rows,
        trace_id: String,
    ) -> Result<Streaming<AppendRowsResponse>, BQError> {
        let write_stream = stream_name.to_string();

        let append_rows_request = AppendRowsRequest {
            write_stream,
            offset: None,
            trace_id,
            missing_value_interpretations: HashMap::new(),
            default_missing_value_interpretation: MissingValueInterpretation::Unspecified.into(),
            rows: Some(rows),
        };

        // append_rows is a client-streaming RPC; this sends a single-element
        // request stream.
        let req = self
            .new_authorized_request(tokio_stream::iter(vec![append_rows_request]))
            .await?;

        let response = self.write_client.append_rows(req).await?;

        let streaming = response.into_inner();

        Ok(streaming)
    }

    /// Serializes `rows` into an `AppendRows` payload, stopping before the
    /// accumulated encoded row size would exceed `max_size_bytes`.
    ///
    /// Returns the payload together with the number of rows actually
    /// included, so callers can re-slice the input and send the remainder in
    /// subsequent calls. The size accounting covers the encoded rows only,
    /// not the writer schema or request envelope.
    ///
    /// NOTE(review): if the first row alone encodes to more than
    /// `max_size_bytes`, zero rows are included and `num_rows_processed` is 0
    /// — a caller that loops on the returned count makes no progress. Confirm
    /// whether an oversized single row should be an error instead.
    pub fn create_rows<M: Message>(
        table_descriptor: &TableDescriptor,
        rows: &[M],
        max_size_bytes: usize,
    ) -> (append_rows_request::Rows, usize) {
        // Build the protobuf writer schema from the table descriptor.
        let field_descriptors = table_descriptor
            .field_descriptors
            .iter()
            .map(|fd| {
                let typ: Type = fd.typ.into();
                let label: Label = fd.mode.into();
                FieldDescriptorProto {
                    name: Some(fd.name.clone()),
                    number: Some(fd.number as i32),
                    label: Some(label.into()),
                    r#type: Some(typ.into()),
                    type_name: None,
                    extendee: None,
                    default_value: None,
                    oneof_index: None,
                    json_name: None,
                    options: None,
                    proto3_optional: None,
                }
            })
            .collect();
        let proto_descriptor = DescriptorProto {
            name: Some("table_schema".to_string()),
            field: field_descriptors,
            extension: vec![],
            nested_type: vec![],
            enum_type: vec![],
            extension_range: vec![],
            oneof_decl: vec![],
            options: None,
            reserved_range: vec![],
            reserved_name: vec![],
        };
        let proto_schema = ProtoSchema {
            proto_descriptor: Some(proto_descriptor),
        };

        let mut serialized_rows = Vec::new();
        let mut total_size = 0;

        // Greedily pack rows in order until the next row would push the
        // accumulated encoded size past the limit.
        for row in rows {
            let encoded_row = row.encode_to_vec();
            let current_size = encoded_row.len();

            if total_size + current_size > max_size_bytes {
                break;
            }

            serialized_rows.push(encoded_row);
            total_size += current_size;
        }

        let num_rows_processed = serialized_rows.len();

        let proto_rows = crate::google::cloud::bigquery::storage::v1::ProtoRows { serialized_rows };

        let proto_data = ProtoData {
            writer_schema: Some(proto_schema),
            rows: Some(proto_rows),
        };
        (append_rows_request::Rows::ProtoRows(proto_data), num_rows_processed)
    }

    /// Wraps `t` in a tonic `Request` carrying a `Bearer` authorization
    /// header obtained from the authenticator.
    async fn new_authorized_request<D>(&self, t: D) -> Result<Request<D>, BQError> {
        let access_token = self.auth.access_token().await?;
        let bearer_token = format!("Bearer {access_token}");
        // try_into validates the token as a gRPC metadata value.
        let bearer_value = bearer_token.as_str().try_into()?;
        let mut req = Request::new(t);
        let meta = req.metadata_mut();
        meta.insert("authorization", bearer_value);
        Ok(req)
    }

    /// Fetches metadata about the write stream named by `stream_name`, with
    /// the requested level of detail (`view`).
    pub async fn get_write_stream(
        &mut self,
        stream_name: &StreamName,
        view: WriteStreamView,
    ) -> Result<WriteStream, BQError> {
        let get_write_stream_request = GetWriteStreamRequest {
            name: stream_name.to_string(),
            view: view.into(),
        };

        let req = self.new_authorized_request(get_write_stream_request).await?;

        let response = self.write_client.get_write_stream(req).await?;
        let write_stream = response.into_inner();

        Ok(write_stream)
    }
}
340
#[cfg(test)]
pub mod test {
    //! Integration tests for the Storage Write API client.
    //!
    //! These tests talk to a real BigQuery project: they require the
    //! environment variables read by `env_vars()` (project, dataset, table,
    //! service-account key file) and create/delete real resources.
    use crate::model::dataset::Dataset;
    use crate::model::field_type::FieldType;
    use crate::model::table::Table;
    use crate::model::table_field_schema::TableFieldSchema;
    use crate::model::table_schema::TableSchema;
    use crate::storage::{ColumnMode, ColumnType, FieldDescriptor, StorageApi, StreamName, TableDescriptor};
    use crate::{env_vars, Client};
    use prost::Message;
    use std::time::{Duration, SystemTime};
    use tokio_stream::StreamExt;

    /// Row message written to the test table. The prost tags (1..=4) match
    /// the `number`s of the `FieldDescriptor`s built in `test_append_rows`.
    #[derive(Clone, PartialEq, Message)]
    struct Actor {
        #[prost(int32, tag = "1")]
        actor_id: i32,

        #[prost(string, tag = "2")]
        first_name: String,

        #[prost(string, tag = "3")]
        last_name: String,

        // Timestamp sent as a string; see the descriptor note below.
        #[prost(string, tag = "4")]
        last_update: String,
    }

    /// End-to-end test: create a dataset and table, then append two rows via
    /// the Storage Write API, once in a single batch and once split into two.
    #[tokio::test]
    async fn test_append_rows() -> Result<(), Box<dyn std::error::Error>> {
        let (ref project_id, ref dataset_id, ref table_id, ref sa_key) = env_vars();
        // Use a dedicated dataset so this test doesn't clash with others.
        let dataset_id = &format!("{dataset_id}_storage");

        let mut client = Client::from_service_account_key_file(sa_key).await?;

        // Best-effort cleanup from previous runs; result intentionally ignored.
        client.dataset().delete_if_exists(project_id, dataset_id, true).await;

        let created_dataset = client.dataset().create(Dataset::new(project_id, dataset_id)).await?;
        assert_eq!(created_dataset.id, Some(format!("{project_id}:{dataset_id}")));

        let table = Table::new(
            project_id,
            dataset_id,
            table_id,
            TableSchema::new(vec![
                TableFieldSchema::new("actor_id", FieldType::Int64),
                TableFieldSchema::new("first_name", FieldType::String),
                TableFieldSchema::new("last_name", FieldType::String),
                TableFieldSchema::new("last_update", FieldType::Timestamp),
            ]),
        );
        let created_table = client
            .table()
            .create(
                table
                    .description("A table used for unit tests")
                    .label("owner", "me")
                    .label("env", "prod")
                    // Expire the table in an hour so leaked test tables clean
                    // themselves up.
                    .expiration_time(SystemTime::now() + Duration::from_secs(3600)),
            )
            .await?;
        assert_eq!(created_table.table_reference.table_id, table_id.to_string());

        // Protobuf descriptors mirroring the table schema above. Note that
        // the TIMESTAMP column is declared as a string field — presumably the
        // service parses the string representation; confirm against the
        // Storage Write API data-type conversion rules.
        let field_descriptors = vec![
            FieldDescriptor {
                name: "actor_id".to_string(),
                number: 1,
                typ: ColumnType::Int64,
                mode: ColumnMode::Nullable,
            },
            FieldDescriptor {
                name: "first_name".to_string(),
                number: 2,
                typ: ColumnType::String,
                mode: ColumnMode::Nullable,
            },
            FieldDescriptor {
                name: "last_name".to_string(),
                number: 3,
                typ: ColumnType::String,
                mode: ColumnMode::Nullable,
            },
            FieldDescriptor {
                name: "last_update".to_string(),
                number: 4,
                typ: ColumnType::String,
                mode: ColumnMode::Nullable,
            },
        ];
        let table_descriptor = TableDescriptor { field_descriptors };

        let actor1 = Actor {
            actor_id: 1,
            first_name: "John".to_string(),
            last_name: "Doe".to_string(),
            last_update: "2007-02-15 09:34:33 UTC".to_string(),
        };

        let actor2 = Actor {
            actor_id: 2,
            first_name: "Jane".to_string(),
            last_name: "Doe".to_string(),
            last_update: "2008-02-15 09:34:33 UTC".to_string(),
        };

        let stream_name = StreamName::new_default(project_id.clone(), dataset_id.clone(), table_id.clone());
        let trace_id = "test_client".to_string();

        let rows: &[Actor] = &[actor1, actor2];

        // 9 MB: comfortably fits both rows, so one append call suffices.
        let max_size = 9 * 1024 * 1024; let num_append_rows_calls = call_append_rows(
            &mut client,
            &table_descriptor,
            &stream_name,
            trace_id.clone(),
            rows,
            max_size,
        )
        .await?;
        assert_eq!(num_append_rows_calls, 1);

        // 50 bytes: too small for both rows at once, forcing two append calls.
        let max_size = 50; let num_append_rows_calls =
            call_append_rows(&mut client, &table_descriptor, &stream_name, trace_id, rows, max_size).await?;
        assert_eq!(num_append_rows_calls, 2);

        Ok(())
    }

    /// Repeatedly batches `rows` under `max_size` and appends each batch,
    /// returning how many append calls were made.
    ///
    /// NOTE(review): if a single row encodes larger than `max_size`,
    /// `create_rows` reports 0 processed and this loop never terminates —
    /// confirm whether that case should be guarded.
    async fn call_append_rows(
        client: &mut Client,
        table_descriptor: &TableDescriptor,
        stream_name: &StreamName,
        trace_id: String,
        mut rows: &[Actor],
        max_size: usize,
    ) -> Result<u8, Box<dyn std::error::Error>> {
        let mut num_append_rows_calls = 0;
        loop {
            let (encoded_rows, num_processed) = StorageApi::create_rows(table_descriptor, rows, max_size);
            let mut streaming = client
                .storage_mut()
                .append_rows(stream_name, encoded_rows, trace_id.clone())
                .await?;

            num_append_rows_calls += 1;

            // Drain the response stream; any server-side error surfaces here.
            while let Some(resp) = streaming.next().await {
                let resp = resp?;
                println!("response: {resp:#?}");
            }

            // All rows sent — done.
            if num_processed == rows.len() {
                break;
            }

            // Drop the rows already sent and batch the remainder.
            rows = &rows[num_processed..];
        }

        Ok(num_append_rows_calls)
    }
}