use deadpool::managed::{Manager, Object, Pool, QueueMode};
use futures::stream::Stream;
use futures::StreamExt;
use pin_project::pin_project;
use prost::Message;
use prost_types::{
    field_descriptor_proto::{Label, Type},
    DescriptorProto, FieldDescriptorProto,
};
use std::ops::Deref;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{
    collections::HashMap,
    convert::TryInto,
    fmt::Display,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
};
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
use tonic::{
    codec::CompressionEncoding,
    transport::{Channel, ClientTlsConfig},
    Request, Status, Streaming,
};

use crate::google::cloud::bigquery::storage::v1::{GetWriteStreamRequest, ProtoRows, WriteStream, WriteStreamView};
use crate::{
    auth::Authenticator,
    error::BQError,
    google::cloud::bigquery::storage::v1::{
        append_rows_request::{self, MissingValueInterpretation, ProtoData},
        big_query_write_client::BigQueryWriteClient,
        AppendRowsRequest, AppendRowsResponse, ProtoSchema,
    },
    BIG_QUERY_V2_URL,
};

/// gRPC endpoint of the BigQuery Storage Write API.
static BIG_QUERY_STORAGE_API_URL: &str = "https://bigquerystorage.googleapis.com";

/// Domain name checked during TLS certificate verification.
static BIGQUERY_STORAGE_API_DOMAIN: &str = "bigquerystorage.googleapis.com";

/// Maximum payload size of a single append request. The Storage Write API
/// limits an AppendRows request to 10 MB, so batches are cut at 9 MB to
/// leave headroom for the request envelope.
const MAX_BATCH_SIZE_BYTES: usize = 9 * 1024 * 1024;

/// Maximum size of a gRPC message the client will encode or decode.
const MAX_MESSAGE_SIZE_BYTES: usize = 20 * 1024 * 1024;

/// Name of the default write stream that every table exposes.
const DEFAULT_STREAM_NAME: &str = "_default";

/// Maximum number of pooled gRPC clients.
const MAX_POOL_SIZE: usize = 32;

/// A cheaply clonable pool of gRPC write clients; clones share the same
/// underlying connections.
#[derive(Clone)]
pub(crate) struct ConnectionPool {
    pool: Pool<BigQueryWriteClientManager>,
}

/// `deadpool` manager that creates and recycles `BigQueryWriteClient`s.
struct BigQueryWriteClientManager;

impl Manager for BigQueryWriteClientManager {
    type Type = BigQueryWriteClient<Channel>;
    type Error = BQError;

    async fn create(&self) -> Result<Self::Type, Self::Error> {
        let tls_config = ClientTlsConfig::new()
            .domain_name(BIGQUERY_STORAGE_API_DOMAIN)
            .with_enabled_roots();

        let channel = Channel::from_static(BIG_QUERY_STORAGE_API_URL)
            .tls_config(tls_config)?
            .connect()
            .await?;

        let client = BigQueryWriteClient::new(channel)
            .max_encoding_message_size(MAX_MESSAGE_SIZE_BYTES)
            .max_decoding_message_size(MAX_MESSAGE_SIZE_BYTES)
            .send_compressed(CompressionEncoding::Gzip)
            .accept_compressed(CompressionEncoding::Gzip);

        Ok(client)
    }

    async fn recycle(
        &self,
        _conn: &mut Self::Type,
        _metrics: &deadpool::managed::Metrics,
    ) -> deadpool::managed::RecycleResult<Self::Error> {
        // A tonic `Channel` transparently reconnects when needed, so pooled
        // clients never have to be invalidated here.
        Ok(())
    }
}

impl ConnectionPool {
    /// Creates the pool; connections are established lazily on first use.
    async fn new() -> Result<Self, BQError> {
        let manager = BigQueryWriteClientManager;
        let pool = Pool::builder(manager)
            .max_size(MAX_POOL_SIZE)
            .queue_mode(QueueMode::Fifo)
            .build()
            .map_err(|e| BQError::ConnectionPoolError(format!("Failed to create connection pool: {}", e)))?;

        Ok(Self { pool })
    }

    /// Checks a client out of the pool, waiting if all clients are in use.
    async fn get_client(&self) -> Result<Object<BigQueryWriteClientManager>, BQError> {
        self.pool
            .get()
            .await
            .map_err(|e| BQError::ConnectionPoolError(format!("Failed to get connection from pool: {}", e)))
    }
}

/// Protobuf column types supported by the Storage Write API.
#[derive(Debug, Copy, Clone)]
pub enum ColumnType {
    Double,
    Float,
    Int64,
    Uint64,
    Int32,
    Fixed64,
    Fixed32,
    Bool,
    String,
    Bytes,
    Uint32,
    Sfixed32,
    Sfixed64,
    Sint32,
    Sint64,
}

impl From<ColumnType> for Type {
    fn from(value: ColumnType) -> Self {
        match value {
            ColumnType::Double => Type::Double,
            ColumnType::Float => Type::Float,
            ColumnType::Int64 => Type::Int64,
            ColumnType::Uint64 => Type::Uint64,
            ColumnType::Int32 => Type::Int32,
            ColumnType::Fixed64 => Type::Fixed64,
            ColumnType::Fixed32 => Type::Fixed32,
            ColumnType::Bool => Type::Bool,
            ColumnType::String => Type::String,
            ColumnType::Bytes => Type::Bytes,
            ColumnType::Uint32 => Type::Uint32,
            ColumnType::Sfixed32 => Type::Sfixed32,
            ColumnType::Sfixed64 => Type::Sfixed64,
            ColumnType::Sint32 => Type::Sint32,
            ColumnType::Sint64 => Type::Sint64,
        }
    }
}

/// Column mode, mirroring BigQuery's NULLABLE, REQUIRED and REPEATED.
#[derive(Debug, Copy, Clone)]
pub enum ColumnMode {
    Nullable,
    Required,
    Repeated,
}

impl From<ColumnMode> for Label {
    fn from(value: ColumnMode) -> Self {
        match value {
            ColumnMode::Nullable => Label::Optional,
            ColumnMode::Required => Label::Required,
            ColumnMode::Repeated => Label::Repeated,
        }
    }
}

/// Describes the schema of a single field in protobuf terms.
#[derive(Debug, Clone)]
pub struct FieldDescriptor {
    /// Field number, starting from 1.
    pub number: u32,
    /// Field name.
    pub name: String,
    /// Field type.
    pub typ: ColumnType,
    /// Field mode.
    pub mode: ColumnMode,
}

/// Describes the schema of a table as a set of field descriptors.
#[derive(Debug, Clone)]
pub struct TableDescriptor {
    /// Descriptors of all the fields.
    pub field_descriptors: Vec<FieldDescriptor>,
}
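
// A minimal sketch of describing a two-column table (field names and
// numbers are illustrative, not part of the API):
//
//     let descriptor = TableDescriptor {
//         field_descriptors: vec![
//             FieldDescriptor {
//                 number: 1,
//                 name: "id".to_string(),
//                 typ: ColumnType::Int64,
//                 mode: ColumnMode::Required,
//             },
//             FieldDescriptor {
//                 number: 2,
//                 name: "name".to_string(),
//                 typ: ColumnType::String,
//                 mode: ColumnMode::Nullable,
//             },
//         ],
//     };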

/// A batch of rows destined for one table via one write stream.
#[derive(Debug)]
pub struct TableBatch<M> {
    /// Stream the rows are appended to.
    pub stream_name: StreamName,
    /// Schema shared by all rows in the batch.
    pub table_descriptor: Arc<TableDescriptor>,
    /// The rows themselves, as prost messages.
    pub rows: Vec<M>,
}

impl<M> TableBatch<M> {
    pub fn new(stream_name: StreamName, table_descriptor: Arc<TableDescriptor>, rows: Vec<M>) -> Self {
        Self {
            stream_name,
            table_descriptor,
            rows,
        }
    }
}

/// Outcome of appending one [`TableBatch`].
#[derive(Debug)]
pub struct BatchAppendResult {
    /// Index of the batch in the originally submitted `Vec`; results may
    /// complete out of order.
    pub batch_index: usize,
    /// One entry per append request sent for this batch.
    pub responses: Vec<Result<AppendRowsResponse, Status>>,
    /// Total encoded bytes sent for this batch.
    pub bytes_sent: usize,
}

impl BatchAppendResult {
    pub fn new(batch_index: usize, responses: Vec<Result<AppendRowsResponse, Status>>, bytes_sent: usize) -> Self {
        Self {
            batch_index,
            responses,
            bytes_sent,
        }
    }

    /// Returns `true` if every response for this batch succeeded.
    pub fn is_success(&self) -> bool {
        self.responses.iter().all(|result| result.is_ok())
    }
}

/// Fully qualified name of a BigQuery write stream.
#[derive(Debug, Clone)]
pub struct StreamName {
    /// Project id.
    project: String,
    /// Dataset id.
    dataset: String,
    /// Table id.
    table: String,
    /// Stream id, e.g. `_default`.
    stream: String,
}

impl StreamName {
    pub fn new(project: String, dataset: String, table: String, stream: String) -> StreamName {
        StreamName {
            project,
            dataset,
            table,
            stream,
        }
    }

    /// Builds the name of the table's default write stream.
    pub fn new_default(project: String, dataset: String, table: String) -> StreamName {
        StreamName {
            project,
            dataset,
            table,
            stream: DEFAULT_STREAM_NAME.to_string(),
        }
    }
}

impl Display for StreamName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let StreamName {
            project,
            dataset,
            table,
            stream,
        } = self;
        f.write_fmt(format_args!(
            "projects/{project}/datasets/{dataset}/tables/{table}/streams/{stream}"
        ))
    }
}
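
// For illustration, a default stream name renders as follows (ids are
// hypothetical):
//
//     let name = StreamName::new_default("p".into(), "d".into(), "t".into());
//     assert_eq!(
//         name.to_string(),
//         "projects/p/datasets/d/tables/t/streams/_default"
//     );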

/// A lazy stream of `AppendRowsRequest`s that slices a batch of rows into
/// requests no larger than [`MAX_BATCH_SIZE_BYTES`].
#[pin_project]
#[derive(Debug)]
pub struct AppendRequestsStream<M> {
    /// The full batch of rows to send.
    #[pin]
    batch: Vec<M>,
    proto_schema: ProtoSchema,
    stream_name: StreamName,
    trace_id: String,
    /// Index of the next row to serialize.
    current_index: usize,
    /// The writer schema only needs to be attached to the first request of
    /// the connection; this flag is cleared after the first yield.
    include_schema_next: bool,
    /// Shared counter of encoded request bytes handed to the transport.
    bytes_sent_counter: Arc<AtomicUsize>,
}

impl<M> AppendRequestsStream<M> {
    fn new(
        batch: Vec<M>,
        proto_schema: ProtoSchema,
        stream_name: StreamName,
        trace_id: String,
        bytes_sent_counter: Arc<AtomicUsize>,
    ) -> Self {
        Self {
            batch,
            proto_schema,
            stream_name,
            trace_id,
            current_index: 0,
            include_schema_next: true,
            bytes_sent_counter,
        }
    }
}

impl<M> Stream for AppendRequestsStream<M>
where
    M: Message,
{
    type Item = AppendRowsRequest;

    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        if *this.current_index >= this.batch.len() {
            return Poll::Ready(None);
        }

        // Greedily pack rows until the next row would push the request over
        // the size limit. A single oversized row is still sent on its own.
        let mut serialized_rows = Vec::new();
        let mut total_size = 0;
        let mut processed_count = 0;

        for msg in this.batch.iter().skip(*this.current_index) {
            let size = msg.encoded_len();
            if total_size + size > MAX_BATCH_SIZE_BYTES && !serialized_rows.is_empty() {
                break;
            }

            let encoded = msg.encode_to_vec();
            debug_assert_eq!(
                encoded.len(),
                size,
                "prost::encoded_len disagrees with encode_to_vec length"
            );

            serialized_rows.push(encoded);
            total_size += size;
            processed_count += 1;
        }

        if serialized_rows.is_empty() {
            return Poll::Ready(None);
        }

        let proto_rows = ProtoRows { serialized_rows };
        let proto_data = ProtoData {
            writer_schema: if *this.include_schema_next {
                Some(this.proto_schema.clone())
            } else {
                None
            },
            rows: Some(proto_rows),
        };

        let append_rows_request = AppendRowsRequest {
            write_stream: this.stream_name.to_string(),
            offset: None,
            trace_id: this.trace_id.clone(),
            missing_value_interpretations: HashMap::new(),
            default_missing_value_interpretation: MissingValueInterpretation::Unspecified.into(),
            rows: Some(append_rows_request::Rows::ProtoRows(proto_data)),
        };

        let request_bytes = append_rows_request.encoded_len();
        this.bytes_sent_counter.fetch_add(request_bytes, Ordering::Relaxed);

        *this.current_index += processed_count;
        if *this.include_schema_next {
            *this.include_schema_next = false;
        }

        Poll::Ready(Some(append_rows_request))
    }
}
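
// Illustrative request sequence: a batch whose rows encode to ~25 MB would
// be yielded as three requests of roughly 9 MB, 9 MB and 7 MB, with only
// the first carrying `writer_schema` (the sizes are hypothetical; the
// behavior follows `poll_next` above):
//
//     let stream = AppendRequestsStream::new(rows, schema, name, trace, counter);
//     // poll 1 -> request with writer_schema = Some(..)
//     // poll 2 -> request with writer_schema = None
//     // poll 3 -> request with writer_schema = None
//     // poll 4 -> None (batch drained)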

/// Client for the BigQuery Storage Write API.
#[derive(Clone)]
pub struct StorageApi {
    connection_pool: ConnectionPool,
    auth: Arc<dyn Authenticator>,
    base_url: String,
}

impl StorageApi {
    pub(crate) async fn new(auth: Arc<dyn Authenticator>) -> Result<Self, BQError> {
        let connection_pool = ConnectionPool::new().await?;

        Ok(Self {
            connection_pool,
            auth,
            base_url: BIG_QUERY_V2_URL.to_string(),
        })
    }

    pub(crate) fn with_base_url(&mut self, base_url: String) -> &mut Self {
        self.base_url = base_url;
        self
    }
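
    /// Serializes as many of `rows` as fit within `max_size_bytes` into a
    /// `Rows` payload and returns it together with the number of rows
    /// consumed; call it in a loop to drain a slice. A minimal sketch, with
    /// `descriptor` and `actors` standing in for caller-provided values:
    ///
    /// ```ignore
    /// let mut remaining: &[Actor] = &actors;
    /// loop {
    ///     let (payload, n) = StorageApi::create_rows(&descriptor, remaining, MAX_BATCH_SIZE_BYTES);
    ///     // ... send `payload` with `append_rows` ...
    ///     if n == remaining.len() {
    ///         break;
    ///     }
    ///     remaining = &remaining[n..];
    /// }
    /// ```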
    pub fn create_rows<M: Message>(
        table_descriptor: &TableDescriptor,
        rows: &[M],
        max_size_bytes: usize,
    ) -> (append_rows_request::Rows, usize) {
        let proto_schema = Self::create_proto_schema(table_descriptor);

        let mut serialized_rows = Vec::new();
        let mut total_size = 0;

        for row in rows {
            let row_size = row.encoded_len();
            if total_size + row_size > max_size_bytes {
                break;
            }

            let encoded_row = row.encode_to_vec();
            debug_assert_eq!(
                encoded_row.len(),
                row_size,
                "prost::encoded_len disagrees with encode_to_vec length"
            );

            serialized_rows.push(encoded_row);
            total_size += row_size;
        }

        let num_rows_processed = serialized_rows.len();

        let proto_rows = ProtoRows { serialized_rows };

        let proto_data = ProtoData {
            writer_schema: Some(proto_schema),
            rows: Some(proto_rows),
        };

        (append_rows_request::Rows::ProtoRows(proto_data), num_rows_processed)
    }

    /// Fetches metadata about a write stream, with the level of detail
    /// controlled by `view`.
    pub async fn get_write_stream(
        &mut self,
        stream_name: &StreamName,
        view: WriteStreamView,
    ) -> Result<WriteStream, BQError> {
        let get_write_stream_request = GetWriteStreamRequest {
            name: stream_name.to_string(),
            view: view.into(),
        };

        let request = Self::new_authorized_request(self.auth.clone(), get_write_stream_request).await?;
        let mut client = self.connection_pool.get_client().await?;
        let response = client.get_write_stream(request).await?;
        let write_stream = response.into_inner();

        Ok(write_stream)
    }
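
    /// Appends one payload of rows to `stream_name` and returns the
    /// server's streaming response. A minimal usage sketch, assuming a
    /// connected `Client` and reusing the names from the `create_rows`
    /// example:
    ///
    /// ```ignore
    /// let (payload, _n) = StorageApi::create_rows(&descriptor, &actors, MAX_BATCH_SIZE_BYTES);
    /// let mut responses = client
    ///     .storage_mut()
    ///     .append_rows(&stream_name, payload, "my_trace_id".to_string())
    ///     .await?;
    /// while let Some(response) = responses.next().await {
    ///     response?;
    /// }
    /// ```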
    pub async fn append_rows(
        &mut self,
        stream_name: &StreamName,
        rows: append_rows_request::Rows,
        trace_id: String,
    ) -> Result<Streaming<AppendRowsResponse>, BQError> {
        let append_rows_request = AppendRowsRequest {
            write_stream: stream_name.to_string(),
            offset: None,
            trace_id,
            missing_value_interpretations: HashMap::new(),
            default_missing_value_interpretation: MissingValueInterpretation::Unspecified.into(),
            rows: Some(rows),
        };

        let request =
            Self::new_authorized_request(self.auth.clone(), tokio_stream::iter(vec![append_rows_request])).await?;
        let mut client = self.connection_pool.get_client().await?;
        let response = client.append_rows(request).await?;
        let streaming = response.into_inner();

        Ok(streaming)
    }
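
    /// Appends several table batches concurrently, keeping at most
    /// `max_concurrent_streams` appends in flight. A minimal sketch,
    /// reusing `descriptor` and `stream_name` from the examples above:
    ///
    /// ```ignore
    /// let batches = vec![
    ///     TableBatch::new(stream_name.clone(), descriptor.clone(), rows_a),
    ///     TableBatch::new(stream_name.clone(), descriptor.clone(), rows_b),
    /// ];
    /// let results = client
    ///     .storage_mut()
    ///     .append_table_batches_concurrent(batches, 2, "my_trace_id")
    ///     .await?;
    /// assert!(results.iter().all(|r| r.is_success()));
    /// ```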
    pub async fn append_table_batches_concurrent<M>(
        &self,
        table_batches: Vec<TableBatch<M>>,
        max_concurrent_streams: usize,
        trace_id: &str,
    ) -> Result<Vec<BatchAppendResult>, BQError>
    where
        M: Message + Send + 'static,
    {
        if table_batches.is_empty() {
            return Ok(Vec::new());
        }

        let batches_num = table_batches.len();
        let semaphore = Arc::new(Semaphore::new(max_concurrent_streams));

        let mut join_set = JoinSet::new();
        for (idx, table_batch) in table_batches.into_iter().enumerate() {
            // Acquiring before spawning caps the number of in-flight tasks.
            let permit = semaphore.clone().acquire_owned().await?;

            let stream_name = table_batch.stream_name.clone();
            let table_descriptor = table_batch.table_descriptor;
            let rows = table_batch.rows;
            let trace_id = trace_id.to_string();
            let client = self.clone();

            join_set.spawn(async move {
                let proto_schema = Self::create_proto_schema(&table_descriptor);

                let bytes_sent_counter = Arc::new(AtomicUsize::new(0));

                let request_stream =
                    AppendRequestsStream::new(rows, proto_schema, stream_name, trace_id, bytes_sent_counter.clone());

                let mut batch_responses = Vec::new();

                match Self::new_authorized_request(client.auth.clone(), request_stream).await {
                    Ok(request) => match client.connection_pool.get_client().await {
                        Ok(write_client) => {
                            // Clone the client out of the pooled object so the
                            // pool slot can be released before this long-running
                            // call completes; tonic clients are cheap to clone.
                            let mut client = write_client.deref().clone();

                            drop(write_client);

                            match client.append_rows(request).await {
                                Ok(response) => {
                                    let mut streaming_response = response.into_inner();
                                    while let Some(response) = streaming_response.next().await {
                                        batch_responses.push(response);
                                    }
                                }
                                Err(status) => {
                                    batch_responses.push(Err(status));
                                }
                            }
                        }
                        Err(pool_err) => {
                            batch_responses.push(Err(Status::unknown(format!("Pool error: {}", pool_err))));
                        }
                    },
                    Err(err) => {
                        batch_responses.push(Err(Status::unknown(err.to_string())));
                    }
                }

                drop(permit);

                BatchAppendResult::new(idx, batch_responses, bytes_sent_counter.load(Ordering::Relaxed))
            });
        }

        let mut batch_results = Vec::with_capacity(batches_num);
        while let Some(batch_result) = join_set.join_next().await {
            let batch_result = batch_result?;
            batch_results.push(batch_result);
        }

        Ok(batch_results)
    }

    /// Wraps `message` in a `Request` carrying a `Bearer` authorization
    /// header obtained from the authenticator.
    async fn new_authorized_request<T>(auth: Arc<dyn Authenticator>, message: T) -> Result<Request<T>, BQError> {
        let access_token = auth.access_token().await?;
        let bearer_token = format!("Bearer {access_token}");
        let bearer_value = bearer_token.as_str().try_into()?;

        let mut request = Request::new(message);
        let meta = request.metadata_mut();
        meta.insert("authorization", bearer_value);

        Ok(request)
    }

    fn create_field_descriptors(table_descriptor: &TableDescriptor) -> Vec<FieldDescriptorProto> {
        table_descriptor
            .field_descriptors
            .iter()
            .map(|fd| {
                let typ: Type = fd.typ.into();
                let label: Label = fd.mode.into();

                FieldDescriptorProto {
                    name: Some(fd.name.clone()),
                    number: Some(fd.number as i32),
                    label: Some(label.into()),
                    r#type: Some(typ.into()),
                    type_name: None,
                    extendee: None,
                    default_value: None,
                    oneof_index: None,
                    json_name: None,
                    options: None,
                    proto3_optional: None,
                }
            })
            .collect()
    }

    fn create_proto_descriptor(field_descriptors: Vec<FieldDescriptorProto>) -> DescriptorProto {
        DescriptorProto {
            name: Some("table_schema".to_string()),
            field: field_descriptors,
            extension: vec![],
            nested_type: vec![],
            enum_type: vec![],
            extension_range: vec![],
            oneof_decl: vec![],
            options: None,
            reserved_range: vec![],
            reserved_name: vec![],
        }
    }

    fn create_proto_schema(table_descriptor: &TableDescriptor) -> ProtoSchema {
        let field_descriptors = Self::create_field_descriptors(table_descriptor);
        let proto_descriptor = Self::create_proto_descriptor(field_descriptors);

        ProtoSchema {
            proto_descriptor: Some(proto_descriptor),
        }
    }
}

#[cfg(test)]
pub mod test {
    use prost::Message;
    use std::sync::Arc;
    use std::time::{Duration, SystemTime};
    use tokio_stream::StreamExt;

    use crate::model::dataset::Dataset;
    use crate::model::field_type::FieldType;
    use crate::model::table::Table;
    use crate::model::table_field_schema::TableFieldSchema;
    use crate::model::table_schema::TableSchema;
    use crate::storage::{
        ColumnMode, ColumnType, ConnectionPool, FieldDescriptor, StorageApi, StreamName, TableBatch, TableDescriptor,
    };
    use crate::{env_vars, Client};

    #[derive(Clone, PartialEq, Message)]
    struct Actor {
        #[prost(int32, tag = "1")]
        actor_id: i32,
        #[prost(string, tag = "2")]
        first_name: String,
        #[prost(string, tag = "3")]
        last_name: String,
        #[prost(string, tag = "4")]
        last_update: String,
    }

    fn create_test_table_descriptor() -> Arc<TableDescriptor> {
        let field_descriptors = vec![
            FieldDescriptor {
                name: "actor_id".to_string(),
                number: 1,
                typ: ColumnType::Int64,
                mode: ColumnMode::Nullable,
            },
            FieldDescriptor {
                name: "first_name".to_string(),
                number: 2,
                typ: ColumnType::String,
                mode: ColumnMode::Nullable,
            },
            FieldDescriptor {
                name: "last_name".to_string(),
                number: 3,
                typ: ColumnType::String,
                mode: ColumnMode::Nullable,
            },
            FieldDescriptor {
                name: "last_update".to_string(),
                number: 4,
                typ: ColumnType::String,
                mode: ColumnMode::Nullable,
            },
        ];

        Arc::new(TableDescriptor { field_descriptors })
    }

    async fn setup_test_table(
        client: &mut Client,
        project_id: &str,
        dataset_id: &str,
        table_id: &str,
    ) -> Result<(), Box<dyn std::error::Error>> {
        client.dataset().delete_if_exists(project_id, dataset_id, true).await;

        let created_dataset = client.dataset().create(Dataset::new(project_id, dataset_id)).await?;
        assert_eq!(created_dataset.id, Some(format!("{project_id}:{dataset_id}")));

        let table = Table::new(
            project_id,
            dataset_id,
            table_id,
            TableSchema::new(vec![
                TableFieldSchema::new("actor_id", FieldType::Int64),
                TableFieldSchema::new("first_name", FieldType::String),
                TableFieldSchema::new("last_name", FieldType::String),
                TableFieldSchema::new("last_update", FieldType::Timestamp),
            ]),
        );
        let created_table = client
            .table()
            .create(
                table
                    .description("A table used for unit tests")
                    .label("owner", "me")
                    .label("env", "prod")
                    .expiration_time(SystemTime::now() + Duration::from_secs(3600)),
            )
            .await?;
        assert_eq!(created_table.table_reference.table_id, table_id.to_string());

        Ok(())
    }

    fn create_test_actor(id: i32, first_name: &str) -> Actor {
        Actor {
            actor_id: id,
            first_name: first_name.to_string(),
            last_name: "Doe".to_string(),
            last_update: "2007-02-15 09:34:33 UTC".to_string(),
        }
    }

    /// Repeatedly calls `create_rows`/`append_rows` until all `rows` have
    /// been sent, returning how many `append_rows` calls were needed.
    async fn call_append_rows(
        client: &mut Client,
        table_descriptor: &TableDescriptor,
        stream_name: &StreamName,
        trace_id: String,
        mut rows: &[Actor],
        max_size: usize,
    ) -> Result<u8, Box<dyn std::error::Error>> {
        let mut num_append_rows_calls = 0;
        loop {
            let (encoded_rows, num_processed) = StorageApi::create_rows(table_descriptor, rows, max_size);
            let mut streaming = client
                .storage_mut()
                .append_rows(stream_name, encoded_rows, trace_id.clone())
                .await?;

            num_append_rows_calls += 1;

            while let Some(response) = streaming.next().await {
                response?;
            }

            if num_processed == rows.len() {
                break;
            }

            rows = &rows[num_processed..];
        }

        Ok(num_append_rows_calls)
    }

    #[tokio::test]
    async fn test_connection_pool() {
        let connection_pool = ConnectionPool::new().await.unwrap();

        let client1 = connection_pool.get_client().await.unwrap();
        let client2 = connection_pool.get_client().await.unwrap();

        // Two simultaneous checkouts must yield two distinct clients.
        assert!(std::ptr::addr_of!(*client1) != std::ptr::addr_of!(*client2));

        drop(client1);
        drop(client2);

        // After both are returned, a checkout must succeed again.
        let client3 = connection_pool.get_client().await.unwrap();
        drop(client3);
    }

    #[tokio::test]
    async fn test_append_rows() {
        let (ref project_id, ref dataset_id, ref table_id, ref sa_key) = env_vars();
        let dataset_id = &format!("{dataset_id}_storage");

        let mut client = Client::from_service_account_key_file(sa_key).await.unwrap();

        setup_test_table(&mut client, project_id, dataset_id, table_id)
            .await
            .unwrap();

        let table_descriptor = create_test_table_descriptor();
        let actor1 = create_test_actor(1, "John");
        let actor2 = create_test_actor(2, "Jane");

        let stream_name = StreamName::new_default(project_id.clone(), dataset_id.clone(), table_id.clone());
        let trace_id = "test_client".to_string();

        let rows: &[Actor] = &[actor1, actor2];

        // Large enough for both rows: a single append_rows call suffices.
        let max_size = 9 * 1024 * 1024;
        let num_append_rows_calls = call_append_rows(
            &mut client,
            &table_descriptor,
            &stream_name,
            trace_id.clone(),
            rows,
            max_size,
        )
        .await
        .unwrap();
        assert_eq!(num_append_rows_calls, 1);

        // Too small for both rows: they must be split across two calls.
        let max_size = 50;
        let num_append_rows_calls =
            call_append_rows(&mut client, &table_descriptor, &stream_name, trace_id, rows, max_size)
                .await
                .unwrap();
        assert_eq!(num_append_rows_calls, 2);
    }

    #[tokio::test]
    async fn test_append_table_batches_concurrent() {
        let (ref project_id, ref dataset_id, ref table_id, ref sa_key) = env_vars();
        let dataset_id = &format!("{dataset_id}_storage_table_batches");

        let mut client = Client::from_service_account_key_file(sa_key).await.unwrap();

        setup_test_table(&mut client, project_id, dataset_id, table_id)
            .await
            .unwrap();

        let table_descriptor = create_test_table_descriptor();
        let stream_name = StreamName::new_default(project_id.clone(), dataset_id.clone(), table_id.clone());
        let trace_id = "test_table_batches";

        let batch1 = TableBatch::new(
            stream_name.clone(),
            table_descriptor.clone(),
            vec![
                create_test_actor(1, "John"),
                create_test_actor(2, "Jane"),
                create_test_actor(3, "Bob"),
                create_test_actor(4, "Alice"),
            ],
        );

        let batch2 = TableBatch::new(
            stream_name.clone(),
            table_descriptor.clone(),
            vec![create_test_actor(5, "Charlie"), create_test_actor(6, "Dave")],
        );

        let batch3 = TableBatch::new(stream_name, table_descriptor, vec![create_test_actor(7, "Eve")]);

        let table_batches = vec![batch1, batch2, batch3];

        let batch_responses = client
            .storage_mut()
            .append_table_batches_concurrent(table_batches, 2, trace_id)
            .await
            .unwrap();

        assert_eq!(batch_responses.len(), 3);

        let mut total_bytes_across_all_batches = 0;
        for batch_result in batch_responses {
            assert!(
                batch_result.is_success(),
                "Batch {} should be successful.",
                batch_result.batch_index,
            );

            for response in &batch_result.responses {
                assert!(response.is_ok(), "Response should be successful: {:?}", response);
            }

            let bytes_sent = batch_result.bytes_sent;
            assert!(
                bytes_sent > 0,
                "Bytes sent should be greater than 0 for batch {}, got: {}",
                batch_result.batch_index,
                bytes_sent
            );

            total_bytes_across_all_batches += bytes_sent;
        }

        assert!(
            total_bytes_across_all_batches > 0,
            "Total bytes sent across all batches should be greater than 0"
        );
    }
}