gcp_bigquery_client/storage.rs

//! Manage the BigQuery Storage Write API.
use std::{collections::HashMap, convert::TryInto, fmt::Display, sync::Arc};

use prost::Message;
use prost_types::{
    field_descriptor_proto::{Label, Type},
    DescriptorProto, FieldDescriptorProto,
};
use tonic::{
    transport::{Channel, ClientTlsConfig},
    Request, Streaming,
};

use crate::google::cloud::bigquery::storage::v1::{GetWriteStreamRequest, WriteStream, WriteStreamView};
use crate::{
    auth::Authenticator,
    error::BQError,
    google::cloud::bigquery::storage::v1::{
        append_rows_request::{self, MissingValueInterpretation, ProtoData},
        big_query_write_client::BigQueryWriteClient,
        AppendRowsRequest, AppendRowsResponse, ProtoSchema,
    },
    BIG_QUERY_V2_URL,
};

static BIG_QUERY_STORAGE_API_URL: &str = "https://bigquerystorage.googleapis.com";
// Domain name used for TLS server-name verification of the Storage API channel.
static BIGQUERY_STORAGE_API_DOMAIN: &str = "bigquerystorage.googleapis.com";

/// Protobuf column type
#[derive(Clone, Copy)]
pub enum ColumnType {
    Double,
    Float,
    Int64,
    Uint64,
    Int32,
    Fixed64,
    Fixed32,
    Bool,
    String,
    Bytes,
    Uint32,
    Sfixed32,
    Sfixed64,
    Sint32,
    Sint64,
}

impl From<ColumnType> for Type {
    fn from(value: ColumnType) -> Self {
        match value {
            ColumnType::Double => Type::Double,
            ColumnType::Float => Type::Float,
            ColumnType::Int64 => Type::Int64,
            ColumnType::Uint64 => Type::Uint64,
            ColumnType::Int32 => Type::Int32,
            ColumnType::Fixed64 => Type::Fixed64,
            ColumnType::Fixed32 => Type::Fixed32,
            ColumnType::Bool => Type::Bool,
            ColumnType::String => Type::String,
            ColumnType::Bytes => Type::Bytes,
            ColumnType::Uint32 => Type::Uint32,
            ColumnType::Sfixed32 => Type::Sfixed32,
            ColumnType::Sfixed64 => Type::Sfixed64,
            ColumnType::Sint32 => Type::Sint32,
            ColumnType::Sint64 => Type::Sint64,
        }
    }
}

/// Column mode
#[derive(Clone, Copy)]
pub enum ColumnMode {
    Nullable,
    Required,
    Repeated,
}

impl From<ColumnMode> for Label {
    fn from(value: ColumnMode) -> Self {
        match value {
            ColumnMode::Nullable => Label::Optional,
            ColumnMode::Required => Label::Required,
            ColumnMode::Repeated => Label::Repeated,
        }
    }
}


/// A struct to describe the schema of a field in protobuf
pub struct FieldDescriptor {
    /// Field numbers start from 1. Each subsequent field should be incremented by 1.
    pub number: u32,

    /// Field name
    pub name: String,

    /// Field type
    pub typ: ColumnType,

    /// Field mode
    pub mode: ColumnMode,
}

/// A struct to describe the schema of a table in protobuf
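///
/// A minimal construction sketch (the field name and numbering below are
/// illustrative only):
///
/// ```
/// use gcp_bigquery_client::storage::{ColumnMode, ColumnType, FieldDescriptor, TableDescriptor};
///
/// let table_descriptor = TableDescriptor {
///     field_descriptors: vec![FieldDescriptor {
///         number: 1,
///         name: "actor_id".to_string(),
///         typ: ColumnType::Int64,
///         mode: ColumnMode::Nullable,
///     }],
/// };
/// assert_eq!(table_descriptor.field_descriptors.len(), 1);
/// ```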
pub struct TableDescriptor {
    /// Descriptors of all the fields
    pub field_descriptors: Vec<FieldDescriptor>,
}

/// A struct representing a stream name
pub struct StreamName {
    /// Name of the project
    project: String,

    /// Name of the dataset
    dataset: String,

    /// Name of the table
    table: String,

    /// Name of the stream
    stream: String,
}

impl StreamName {
    pub fn new(project: String, dataset: String, table: String, stream: String) -> StreamName {
        StreamName {
            project,
            dataset,
            table,
            stream,
        }
    }

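    /// Creates a `StreamName` pointing at a table's `_default` write stream.
    ///
    /// A small sketch with illustrative names (assumes this module is exported
    /// as `gcp_bigquery_client::storage`):
    ///
    /// ```
    /// use gcp_bigquery_client::storage::StreamName;
    ///
    /// let stream_name = StreamName::new_default("my-project".into(), "my_dataset".into(), "my_table".into());
    /// assert_eq!(
    ///     stream_name.to_string(),
    ///     "projects/my-project/datasets/my_dataset/tables/my_table/streams/_default"
    /// );
    /// ```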
    pub fn new_default(project: String, dataset: String, table: String) -> StreamName {
        StreamName {
            project,
            dataset,
            table,
            stream: "_default".to_string(),
        }
    }
}

impl Display for StreamName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let StreamName {
            project,
            dataset,
            table,
            stream,
        } = self;
        f.write_fmt(format_args!(
            "projects/{project}/datasets/{dataset}/tables/{table}/streams/{stream}"
        ))
    }
}

/// A BigQuery Storage API handler.
#[derive(Clone)]
pub struct StorageApi {
    write_client: BigQueryWriteClient<Channel>,
    auth: Arc<dyn Authenticator>,
    base_url: String,
}

impl StorageApi {
    pub(crate) fn new(write_client: BigQueryWriteClient<Channel>, auth: Arc<dyn Authenticator>) -> Self {
        Self {
            write_client,
            auth,
            base_url: BIG_QUERY_V2_URL.to_string(),
        }
    }

    pub(crate) async fn new_write_client() -> Result<BigQueryWriteClient<Channel>, BQError> {
        // Since Tonic 0.12.0, TLS root certificates are no longer implicit.
        // We need to specify them explicitly.
        // See: https://github.com/hyperium/tonic/pull/1731
        let tls_config = ClientTlsConfig::new()
            .domain_name(BIGQUERY_STORAGE_API_DOMAIN)
            .with_native_roots();
        let channel = Channel::from_static(BIG_QUERY_STORAGE_API_URL)
            .tls_config(tls_config)?
            .connect()
            .await?;
        let write_client = BigQueryWriteClient::new(channel);

        Ok(write_client)
    }

    pub(crate) fn with_base_url(&mut self, base_url: String) -> &mut Self {
        self.base_url = base_url;
        self
    }

    /// Append rows to a table via the BigQuery Storage Write API.
    pub async fn append_rows(
        &mut self,
        stream_name: &StreamName,
        rows: append_rows_request::Rows,
        trace_id: String,
    ) -> Result<Streaming<AppendRowsResponse>, BQError> {
        let write_stream = stream_name.to_string();

        let append_rows_request = AppendRowsRequest {
            write_stream,
            offset: None,
            trace_id,
            missing_value_interpretations: HashMap::new(),
            default_missing_value_interpretation: MissingValueInterpretation::Unspecified.into(),
            rows: Some(rows),
        };

        let req = self
            .new_authorized_request(tokio_stream::iter(vec![append_rows_request]))
            .await?;

        let response = self.write_client.append_rows(req).await?;

        let streaming = response.into_inner();

        Ok(streaming)
    }

    /// Encodes the `rows` slice into a protobuf message while ensuring that the
    /// total size of the encoded rows does not exceed the `max_size_bytes`
    /// argument. The encoded rows are returned in the first value of the tuple
    /// returned by this function.
    ///
    /// Note that not all the rows in the `rows` slice may have been encoded
    /// because of the `max_size_bytes` limit. Callers can find out how many rows
    /// were processed by looking at the second value of the returned tuple. If
    /// that number is less than the number of rows in the `rows` slice, the
    /// caller can call this function again with the remaining rows at the end
    /// of the slice to encode them.
    ///
    /// The AppendRows API has a payload size limit of 10 MB. Some of that space
    /// is used by the request metadata, so `max_size_bytes` should be set to a
    /// value below 10 MB; 9 MB is a good choice.
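    ///
    /// A sketch of the chunked-append loop (the `storage`, `table_descriptor`,
    /// `stream_name`, and `rows` values are assumed to be provided by the caller):
    ///
    /// ```no_run
    /// # use gcp_bigquery_client::error::BQError;
    /// # use gcp_bigquery_client::storage::{StorageApi, StreamName, TableDescriptor};
    /// # async fn append_all<M: prost::Message>(
    /// #     storage: &mut StorageApi,
    /// #     table_descriptor: &TableDescriptor,
    /// #     stream_name: &StreamName,
    /// #     mut rows: &[M],
    /// # ) -> Result<(), BQError> {
    /// // Keep each request comfortably below the 10 MB AppendRows limit.
    /// let max_size_bytes = 9 * 1024 * 1024;
    /// loop {
    ///     let (encoded_rows, num_processed) =
    ///         StorageApi::create_rows(table_descriptor, rows, max_size_bytes);
    ///     storage
    ///         .append_rows(stream_name, encoded_rows, "example_trace".to_string())
    ///         .await?;
    ///     if num_processed == rows.len() {
    ///         break;
    ///     }
    ///     // Encode and append the rows that did not fit in this batch.
    ///     rows = &rows[num_processed..];
    /// }
    /// # Ok(())
    /// # }
    /// ```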
    pub fn create_rows<M: Message>(
        table_descriptor: &TableDescriptor,
        rows: &[M],
        max_size_bytes: usize,
    ) -> (append_rows_request::Rows, usize) {
        let field_descriptors = table_descriptor
            .field_descriptors
            .iter()
            .map(|fd| {
                let typ: Type = fd.typ.into();
                let label: Label = fd.mode.into();
                FieldDescriptorProto {
                    name: Some(fd.name.clone()),
                    number: Some(fd.number as i32),
                    label: Some(label.into()),
                    r#type: Some(typ.into()),
                    type_name: None,
                    extendee: None,
                    default_value: None,
                    oneof_index: None,
                    json_name: None,
                    options: None,
                    proto3_optional: None,
                }
            })
            .collect();
        let proto_descriptor = DescriptorProto {
            name: Some("table_schema".to_string()),
            field: field_descriptors,
            extension: vec![],
            nested_type: vec![],
            enum_type: vec![],
            extension_range: vec![],
            oneof_decl: vec![],
            options: None,
            reserved_range: vec![],
            reserved_name: vec![],
        };
        let proto_schema = ProtoSchema {
            proto_descriptor: Some(proto_descriptor),
        };

        let mut serialized_rows = Vec::new();
        let mut total_size = 0;

        for row in rows {
            let encoded_row = row.encode_to_vec();
            let current_size = encoded_row.len();

            if total_size + current_size > max_size_bytes {
                break;
            }

            serialized_rows.push(encoded_row);
            total_size += current_size;
        }

        let num_rows_processed = serialized_rows.len();

        let proto_rows = crate::google::cloud::bigquery::storage::v1::ProtoRows { serialized_rows };

        let proto_data = ProtoData {
            writer_schema: Some(proto_schema),
            rows: Some(proto_rows),
        };
        (append_rows_request::Rows::ProtoRows(proto_data), num_rows_processed)
    }

    /// Builds a tonic request for `t` with the `authorization` header set to a
    /// `Bearer` token obtained from the authenticator.
    async fn new_authorized_request<D>(&self, t: D) -> Result<Request<D>, BQError> {
        let access_token = self.auth.access_token().await?;
        let bearer_token = format!("Bearer {access_token}");
        let bearer_value = bearer_token.as_str().try_into()?;
        let mut req = Request::new(t);
        let meta = req.metadata_mut();
        meta.insert("authorization", bearer_value);
        Ok(req)
    }

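    /// Gets information about a write stream (the `GetWriteStream` RPC).
    ///
    /// A minimal sketch, assuming a `Client` has already been built elsewhere and
    /// using placeholder project/dataset/table names:
    ///
    /// ```no_run
    /// # use gcp_bigquery_client::error::BQError;
    /// # use gcp_bigquery_client::google::cloud::bigquery::storage::v1::WriteStreamView;
    /// # use gcp_bigquery_client::storage::StreamName;
    /// # use gcp_bigquery_client::Client;
    /// # async fn example(client: &mut Client) -> Result<(), BQError> {
    /// let stream_name = StreamName::new_default("my-project".into(), "my_dataset".into(), "my_table".into());
    /// let write_stream = client
    ///     .storage_mut()
    ///     .get_write_stream(&stream_name, WriteStreamView::Full)
    ///     .await?;
    /// println!("{write_stream:?}");
    /// # Ok(())
    /// # }
    /// ```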
    pub async fn get_write_stream(
        &mut self,
        stream_name: &StreamName,
        view: WriteStreamView,
    ) -> Result<WriteStream, BQError> {
        let get_write_stream_request = GetWriteStreamRequest {
            name: stream_name.to_string(),
            view: view.into(),
        };

        let req = self.new_authorized_request(get_write_stream_request).await?;

        let response = self.write_client.get_write_stream(req).await?;
        let write_stream = response.into_inner();

        Ok(write_stream)
    }
}

#[cfg(test)]
pub mod test {
    use crate::model::dataset::Dataset;
    use crate::model::field_type::FieldType;
    use crate::model::table::Table;
    use crate::model::table_field_schema::TableFieldSchema;
    use crate::model::table_schema::TableSchema;
    use crate::storage::{ColumnMode, ColumnType, FieldDescriptor, StorageApi, StreamName, TableDescriptor};
    use crate::{env_vars, Client};
    use prost::Message;
    use std::time::{Duration, SystemTime};
    use tokio_stream::StreamExt;

    #[derive(Clone, PartialEq, Message)]
    struct Actor {
        #[prost(int32, tag = "1")]
        actor_id: i32,

        #[prost(string, tag = "2")]
        first_name: String,

        #[prost(string, tag = "3")]
        last_name: String,

        #[prost(string, tag = "4")]
        last_update: String,
    }

    #[tokio::test]
    async fn test_append_rows() -> Result<(), Box<dyn std::error::Error>> {
        let (ref project_id, ref dataset_id, ref table_id, ref sa_key) = env_vars();
        let dataset_id = &format!("{dataset_id}_storage");

        let mut client = Client::from_service_account_key_file(sa_key).await?;

        // Delete the dataset if needed
        client.dataset().delete_if_exists(project_id, dataset_id, true).await;

        // Create dataset
        let created_dataset = client.dataset().create(Dataset::new(project_id, dataset_id)).await?;
        assert_eq!(created_dataset.id, Some(format!("{project_id}:{dataset_id}")));

        // Create table
        let table = Table::new(
            project_id,
            dataset_id,
            table_id,
            TableSchema::new(vec![
                TableFieldSchema::new("actor_id", FieldType::Int64),
                TableFieldSchema::new("first_name", FieldType::String),
                TableFieldSchema::new("last_name", FieldType::String),
                TableFieldSchema::new("last_update", FieldType::Timestamp),
            ]),
        );
        let created_table = client
            .table()
            .create(
                table
                    .description("A table used for unit tests")
                    .label("owner", "me")
                    .label("env", "prod")
                    .expiration_time(SystemTime::now() + Duration::from_secs(3600)),
            )
            .await?;
        assert_eq!(created_table.table_reference.table_id, table_id.to_string());

        let field_descriptors = vec![
            FieldDescriptor {
                name: "actor_id".to_string(),
                number: 1,
                typ: ColumnType::Int64,
                mode: ColumnMode::Nullable,
            },
            FieldDescriptor {
                name: "first_name".to_string(),
                number: 2,
                typ: ColumnType::String,
                mode: ColumnMode::Nullable,
            },
            FieldDescriptor {
                name: "last_name".to_string(),
                number: 3,
                typ: ColumnType::String,
                mode: ColumnMode::Nullable,
            },
            FieldDescriptor {
                name: "last_update".to_string(),
                number: 4,
                typ: ColumnType::String,
                mode: ColumnMode::Nullable,
            },
        ];
        let table_descriptor = TableDescriptor { field_descriptors };

        let actor1 = Actor {
            actor_id: 1,
            first_name: "John".to_string(),
            last_name: "Doe".to_string(),
            last_update: "2007-02-15 09:34:33 UTC".to_string(),
        };

        let actor2 = Actor {
            actor_id: 2,
            first_name: "Jane".to_string(),
            last_name: "Doe".to_string(),
            last_update: "2008-02-15 09:34:33 UTC".to_string(),
        };

        let stream_name = StreamName::new_default(project_id.clone(), dataset_id.clone(), table_id.clone());
        let trace_id = "test_client".to_string();

        let rows: &[Actor] = &[actor1, actor2];

        let max_size = 9 * 1024 * 1024; // 9 MB
        let num_append_rows_calls = call_append_rows(
            &mut client,
            &table_descriptor,
            &stream_name,
            trace_id.clone(),
            rows,
            max_size,
        )
        .await?;
        assert_eq!(num_append_rows_calls, 1);

        // Experimentation showed that one row in this test encodes to about 38 bytes.
        // We artificially limit the batch size so that the loop in `call_append_rows`
        // needs more than one call to process all the rows.
        let max_size = 50; // 50 bytes
        let num_append_rows_calls =
            call_append_rows(&mut client, &table_descriptor, &stream_name, trace_id, rows, max_size).await?;
        assert_eq!(num_append_rows_calls, 2);

        Ok(())
    }

    async fn call_append_rows(
        client: &mut Client,
        table_descriptor: &TableDescriptor,
        stream_name: &StreamName,
        trace_id: String,
        mut rows: &[Actor],
        max_size: usize,
    ) -> Result<u8, Box<dyn std::error::Error>> {
        // This loop is needed because the AppendRows API has a payload size limit of 10 MB, so create_rows
        // may not process all the rows in the `rows` slice in a single call. Even though this test only
        // sends two rows (far below the 10 MB limit), a real-world caller may have many more rows and
        // needs the loop to process them all.
        let mut num_append_rows_calls = 0;
        loop {
            let (encoded_rows, num_processed) = StorageApi::create_rows(table_descriptor, rows, max_size);
            let mut streaming = client
                .storage_mut()
                .append_rows(stream_name, encoded_rows, trace_id.clone())
                .await?;

            num_append_rows_calls += 1;

            while let Some(resp) = streaming.next().await {
                let resp = resp?;
                println!("response: {resp:#?}");
            }

            // All the rows have been processed
            if num_processed == rows.len() {
                break;
            }

            // Process the remaining rows
            rows = &rows[num_processed..];
        }

        Ok(num_append_rows_calls)
    }
}