gcp_bigquery_client/storage.rs

//! BigQuery Storage Write API client for high-throughput data streaming.
//!
//! This module provides a client for the BigQuery Storage Write API,
//! enabling efficient streaming of structured data to BigQuery tables.

use deadpool::managed::{Manager, Object, Pool, QueueMode};
use futures::stream::Stream;
use futures::StreamExt;
use pin_project::pin_project;
use prost::Message;
use prost_types::{
    field_descriptor_proto::{Label, Type},
    DescriptorProto, FieldDescriptorProto,
};
use std::ops::Deref;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{
    collections::HashMap,
    convert::TryInto,
    fmt::Display,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
};
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
use tonic::{
    codec::CompressionEncoding,
    transport::{Channel, ClientTlsConfig},
    Request, Status, Streaming,
};

use crate::google::cloud::bigquery::storage::v1::{GetWriteStreamRequest, ProtoRows, WriteStream, WriteStreamView};
use crate::{
    auth::Authenticator,
    error::BQError,
    google::cloud::bigquery::storage::v1::{
        append_rows_request::{self, MissingValueInterpretation, ProtoData},
        big_query_write_client::BigQueryWriteClient,
        AppendRowsRequest, AppendRowsResponse, ProtoSchema,
    },
    BIG_QUERY_V2_URL,
};

/// Base URL for the BigQuery Storage Write API endpoint.
static BIG_QUERY_STORAGE_API_URL: &str = "https://bigquerystorage.googleapis.com";
/// Domain name for BigQuery Storage API used in TLS configuration.
static BIGQUERY_STORAGE_API_DOMAIN: &str = "bigquerystorage.googleapis.com";
/// Maximum size limit for batched append requests in bytes.
///
/// Set to 9MB to provide a safety margin under the 10MB BigQuery API limit,
/// accounting for request metadata overhead.
const MAX_BATCH_SIZE_BYTES: usize = 9 * 1024 * 1024;
/// Maximum message size for tonic gRPC client configuration.
///
/// Set to 20MB to accommodate large response messages and provide headroom
/// for metadata while staying within reasonable memory bounds.
const MAX_MESSAGE_SIZE_BYTES: usize = 20 * 1024 * 1024;
/// The name of the default stream in BigQuery.
///
/// This stream is a special built-in stream that always exists for a table.
const DEFAULT_STREAM_NAME: &str = "_default";
/// Max number of connections in the pool.
const MAX_POOL_SIZE: usize = 32;

/// Connection pool for managing multiple gRPC clients to BigQuery Storage Write API.
///
/// Uses deadpool to maintain a pool of persistent connections with proper
/// resource management, automatic cleanup, and efficient connection reuse.
#[derive(Clone)]
pub(crate) struct ConnectionPool {
    /// Deadpool pool of BigQuery Storage Write API clients.
    pool: Pool<BigQueryWriteClientManager>,
}

/// Manager for creating and managing BigQuery Storage Write API gRPC clients.
///
/// Implements the deadpool Manager trait to handle connection lifecycle
/// including creation, health checks, and cleanup of gRPC clients.
struct BigQueryWriteClientManager;

impl Manager for BigQueryWriteClientManager {
    type Type = BigQueryWriteClient<Channel>;
    type Error = BQError;

    /// Creates a new gRPC client for BigQuery Storage Write API.
    ///
    /// Establishes a secure TLS connection with compression and size limits
    /// configured for optimal performance.
    async fn create(&self) -> Result<Self::Type, Self::Error> {
        // Since Tonic 0.12.0, TLS root certificates are no longer included by default.
        // They must now be specified explicitly.
        // See: https://github.com/hyperium/tonic/pull/1731
        let tls_config = ClientTlsConfig::new()
            .domain_name(BIGQUERY_STORAGE_API_DOMAIN)
            .with_enabled_roots();

        let channel = Channel::from_static(BIG_QUERY_STORAGE_API_URL)
            .tls_config(tls_config)?
            .connect()
            .await?;

        let client = BigQueryWriteClient::new(channel)
            .max_encoding_message_size(MAX_MESSAGE_SIZE_BYTES)
            .max_decoding_message_size(MAX_MESSAGE_SIZE_BYTES)
            .send_compressed(CompressionEncoding::Gzip)
            .accept_compressed(CompressionEncoding::Gzip);

        Ok(client)
    }

    /// Recycles a gRPC client connection.
    ///
    /// Currently always returns Ok(()) as gRPC clients don't require
    /// special recycling. In the future, this could perform health checks
    /// or connection validation.
    async fn recycle(
        &self,
        _conn: &mut Self::Type,
        _metrics: &deadpool::managed::Metrics,
    ) -> deadpool::managed::RecycleResult<Self::Error> {
        Ok(())
    }
}

impl ConnectionPool {
    /// Creates a new connection pool of BigQuery Storage Write API clients (up to [`MAX_POOL_SIZE`]).
    ///
    /// Establishes a managed pool that creates connections on-demand and
    /// recycles them efficiently for optimal performance.
    async fn new() -> Result<Self, BQError> {
        let manager = BigQueryWriteClientManager;
        let pool = Pool::builder(manager)
            .max_size(MAX_POOL_SIZE)
            // We must use Fifo since we want to always get the least recently used connection to cycle
            // through connections in the pool.
            .queue_mode(QueueMode::Fifo)
            .build()
            .map_err(|e| BQError::ConnectionPoolError(format!("Failed to create connection pool: {}", e)))?;

        Ok(Self { pool })
    }

    /// Retrieves a client from the pool.
    ///
    /// Returns a managed connection object that automatically returns
    /// the connection to the pool when dropped.
    async fn get_client(&self) -> Result<Object<BigQueryWriteClientManager>, BQError> {
        self.pool
            .get()
            .await
            .map_err(|e| BQError::ConnectionPoolError(format!("Failed to get connection from pool: {}", e)))
    }
}

/// Supported protobuf column types for BigQuery schema mapping.
#[derive(Debug, Copy, Clone)]
pub enum ColumnType {
    Double,
    Float,
    Int64,
    Uint64,
    Int32,
    Fixed64,
    Fixed32,
    Bool,
    String,
    Bytes,
    Uint32,
    Sfixed32,
    Sfixed64,
    Sint32,
    Sint64,
}

impl From<ColumnType> for Type {
    /// Converts [`ColumnType`] to protobuf [`Type`] enum value.
    ///
    /// Maps each column type variant to its corresponding protobuf type
    /// identifier used in BigQuery Storage Write API schema definitions.
    fn from(value: ColumnType) -> Self {
        match value {
            ColumnType::Double => Type::Double,
            ColumnType::Float => Type::Float,
            ColumnType::Int64 => Type::Int64,
            ColumnType::Uint64 => Type::Uint64,
            ColumnType::Int32 => Type::Int32,
            ColumnType::Fixed64 => Type::Fixed64,
            ColumnType::Fixed32 => Type::Fixed32,
            ColumnType::Bool => Type::Bool,
            ColumnType::String => Type::String,
            ColumnType::Bytes => Type::Bytes,
            ColumnType::Uint32 => Type::Uint32,
            ColumnType::Sfixed32 => Type::Sfixed32,
            ColumnType::Sfixed64 => Type::Sfixed64,
            ColumnType::Sint32 => Type::Sint32,
            ColumnType::Sint64 => Type::Sint64,
        }
    }
}

/// Field cardinality modes for BigQuery schema fields.
#[derive(Debug, Copy, Clone)]
pub enum ColumnMode {
    /// Field may contain null values.
    Nullable,
    /// Field must always contain a value.
    Required,
    /// Field contains an array of values.
    Repeated,
}

impl From<ColumnMode> for Label {
    /// Converts [`ColumnMode`] to protobuf [`Label`] enum value.
    ///
    /// Maps field cardinality modes to their corresponding protobuf labels
    /// used in BigQuery Storage Write API schema definitions.
    fn from(value: ColumnMode) -> Self {
        match value {
            ColumnMode::Nullable => Label::Optional,
            ColumnMode::Required => Label::Required,
            ColumnMode::Repeated => Label::Repeated,
        }
    }
}

/// Metadata descriptor for a single field in a BigQuery table schema.
///
/// Contains the complete field definition including data type, cardinality mode,
/// and protobuf field number required for BigQuery Storage Write API operations.
/// Each field descriptor maps directly to a protobuf field descriptor in the
/// generated schema.
#[derive(Debug, Clone)]
pub struct FieldDescriptor {
    /// Unique field number starting from 1, incrementing for each field.
    pub number: u32,
    /// Name of the field as it appears in BigQuery.
    pub name: String,
    /// Data type of the field.
    pub typ: ColumnType,
    /// Cardinality mode of the field.
    pub mode: ColumnMode,
}

/// Complete schema definition for a BigQuery table.
///
/// Aggregates all field descriptors that define the table's structure.
/// Used to generate protobuf schemas for BigQuery Storage Write API operations
/// and validate row data before transmission.
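///
/// A minimal construction sketch (not compiled as a doctest; the single `id`
/// field below is illustrative, not tied to any real table):
///
/// ```ignore
/// let descriptor = TableDescriptor {
///     field_descriptors: vec![FieldDescriptor {
///         number: 1,
///         name: "id".to_string(),
///         typ: ColumnType::Int64,
///         mode: ColumnMode::Required,
///     }],
/// };
/// ```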
#[derive(Debug, Clone)]
pub struct TableDescriptor {
    /// Collection of field descriptors defining the table schema.
    pub field_descriptors: Vec<FieldDescriptor>,
}

/// Collection of rows targeting a specific BigQuery table for batch processing.
///
/// Encapsulates rows with their destination stream and schema metadata,
/// enabling efficient batch operations and optimal parallelism distribution
/// across multiple tables in concurrent append operations.
#[derive(Debug)]
pub struct TableBatch<M> {
    /// Target stream identifier for the append operations.
    pub stream_name: StreamName,
    /// Schema descriptor for the target table.
    pub table_descriptor: Arc<TableDescriptor>,
    /// Collection of rows to be appended to the table.
    pub rows: Vec<M>,
}

impl<M> TableBatch<M> {
    /// Creates a new table batch targeting the specified stream.
    ///
    /// Combines rows with their destination metadata to form a complete
    /// batch ready for processing by append operations.
    pub fn new(stream_name: StreamName, table_descriptor: Arc<TableDescriptor>, rows: Vec<M>) -> Self {
        Self {
            stream_name,
            table_descriptor,
            rows,
        }
    }
}

/// Result of processing a single table batch in concurrent append operations.
///
/// Contains the batch processing results along with metadata about the operation,
/// including the original batch index for result ordering and total bytes sent
/// for monitoring and debugging purposes.
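///
/// A minimal result-handling sketch (not compiled as a doctest; assumes `results`
/// is the vector returned by [`StorageApi::append_table_batches_concurrent`]):
///
/// ```ignore
/// for result in results {
///     if !result.is_success() {
///         eprintln!("batch {} failed ({} bytes sent)", result.batch_index, result.bytes_sent);
///     }
/// }
/// ```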
#[derive(Debug)]
pub struct BatchAppendResult {
    /// Original index of the batch in the input vector.
    ///
    /// Allows callers to correlate results with their original batch ordering
    /// even when results are returned in completion order rather than submission order.
    pub batch_index: usize,
    /// Collection of append operation responses for this batch.
    ///
    /// Each batch may generate multiple append requests due to size limits,
    /// resulting in multiple responses. All responses must be checked for
    /// errors to ensure complete batch success.
    pub responses: Vec<Result<AppendRowsResponse, Status>>,
    /// Total bytes sent for this batch across all requests.
    pub bytes_sent: usize,
}

impl BatchAppendResult {
    /// Creates a new batch append result.
    ///
    /// Combines all result metadata into a single cohesive structure
    /// for easier handling by calling code.
    pub fn new(batch_index: usize, responses: Vec<Result<AppendRowsResponse, Status>>, bytes_sent: usize) -> Self {
        Self {
            batch_index,
            responses,
            bytes_sent,
        }
    }

    /// Returns true if all responses in this batch are successful.
    ///
    /// Convenience method to quickly check batch success without
    /// iterating through individual responses.
    pub fn is_success(&self) -> bool {
        self.responses.iter().all(|result| result.is_ok())
    }
}

/// Hierarchical identifier for BigQuery write streams.
///
/// Represents the complete resource path structure used by BigQuery to
/// uniquely identify tables and their associated write streams within
/// the Google Cloud resource hierarchy.
#[derive(Debug, Clone)]
pub struct StreamName {
    /// Google Cloud project identifier.
    project: String,
    /// BigQuery dataset identifier within the project.
    dataset: String,
    /// BigQuery table identifier within the dataset.
    table: String,
    /// Write stream identifier for the table.
    stream: String,
}

impl StreamName {
    /// Creates a stream name with all components specified.
    ///
    /// Constructs a fully qualified stream identifier using custom
    /// project, dataset, table, and stream components.
    pub fn new(project: String, dataset: String, table: String, stream: String) -> StreamName {
        StreamName {
            project,
            dataset,
            table,
            stream,
        }
    }

    /// Creates a stream name using the default stream identifier.
    ///
    /// Uses "_default" as the stream component, which is the standard
    /// stream identifier for most BigQuery write operations.
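    ///
    /// A minimal usage sketch (not compiled as a doctest; the identifiers are placeholders):
    ///
    /// ```ignore
    /// let name = StreamName::new_default("p".to_string(), "d".to_string(), "t".to_string());
    /// assert_eq!(
    ///     name.to_string(),
    ///     "projects/p/datasets/d/tables/t/streams/_default"
    /// );
    /// ```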
    pub fn new_default(project: String, dataset: String, table: String) -> StreamName {
        StreamName {
            project,
            dataset,
            table,
            stream: DEFAULT_STREAM_NAME.to_string(),
        }
    }
}

impl Display for StreamName {
    /// Formats the stream name as a BigQuery resource path.
    ///
    /// Produces the fully qualified resource identifier expected by
    /// BigQuery Storage Write API in the format:
    /// `projects/{project}/datasets/{dataset}/tables/{table}/streams/{stream}`
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let StreamName {
            project,
            dataset,
            table,
            stream,
        } = self;
        f.write_fmt(format_args!(
            "projects/{project}/datasets/{dataset}/tables/{table}/streams/{stream}"
        ))
    }
}

/// Streaming adapter that converts message batches into [`AppendRowsRequest`] objects.
///
/// Automatically chunks large batches into multiple requests while respecting
/// the 10MB BigQuery API size limit. If a single row exceeds the configured
/// limit, it is sent alone and may be rejected by the server. Implements [`Stream`] for seamless
/// integration with async streaming workflows and gRPC client operations.
#[pin_project]
#[derive(Debug)]
pub struct AppendRequestsStream<M> {
    /// Collection of messages to be converted into append requests.
    #[pin]
    batch: Vec<M>,
    /// Protobuf schema definition for the target table.
    proto_schema: ProtoSchema,
    /// Target stream identifier for the append operations.
    stream_name: StreamName,
    /// Unique identifier for tracing and debugging requests.
    trace_id: String,
    /// Current position in the batch being processed.
    current_index: usize,
    /// Whether to include writer schema in the next request (first only).
    ///
    /// This boolean is used under the assumption that a batch of append requests belongs to the same
    /// table and has no schema differences between the rows.
    include_schema_next: bool,
    /// Shared atomic counter for tracking total bytes sent across all requests in this stream.
    bytes_sent_counter: Arc<AtomicUsize>,
}

impl<M> AppendRequestsStream<M> {
    /// Creates a new streaming adapter from message batch components.
    ///
    /// Initializes the stream with all necessary metadata for generating
    /// properly formatted append requests. The schema is included only
    /// in the first request of the stream.
    fn new(
        batch: Vec<M>,
        proto_schema: ProtoSchema,
        stream_name: StreamName,
        trace_id: String,
        bytes_sent_counter: Arc<AtomicUsize>,
    ) -> Self {
        Self {
            batch,
            proto_schema,
            stream_name,
            trace_id,
            current_index: 0,
            include_schema_next: true,
            bytes_sent_counter,
        }
    }
}

impl<M> Stream for AppendRequestsStream<M>
where
    M: Message,
{
    type Item = AppendRowsRequest;

    /// Produces the next append request from the message batch.
    ///
    /// Processes messages sequentially, accumulating them into requests
    /// until the size limit is reached. Returns [`Poll::Ready(None)`]
    /// when all messages have been consumed. Each request contains the
    /// maximum number of messages that fit within size constraints.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        if *this.current_index >= this.batch.len() {
            return Poll::Ready(None);
        }

        let mut serialized_rows = Vec::new();
        let mut total_size = 0;
        let mut processed_count = 0;

        // Process messages from `current_index` onwards. We do not mutate the vector while
        // processing, to avoid unnecessary reallocations.
        for msg in this.batch.iter().skip(*this.current_index) {
            // First, check the encoded length to avoid performing a full encode
            // on the first message that would exceed the limit and be dropped.
            let size = msg.encoded_len();
            if total_size + size > MAX_BATCH_SIZE_BYTES && !serialized_rows.is_empty() {
                break;
            }

            // Safe to encode now and include the row in this request chunk.
            let encoded = msg.encode_to_vec();
            debug_assert_eq!(
                encoded.len(),
                size,
                "prost::encoded_len disagrees with encode_to_vec length"
            );

            serialized_rows.push(encoded);
            total_size += size;
            processed_count += 1;
        }

        if serialized_rows.is_empty() {
            return Poll::Ready(None);
        }

        let proto_rows = ProtoRows { serialized_rows };
        let proto_data = ProtoData {
            writer_schema: if *this.include_schema_next {
                Some(this.proto_schema.clone())
            } else {
                None
            },
            rows: Some(proto_rows),
        };

        let append_rows_request = AppendRowsRequest {
            write_stream: this.stream_name.to_string(),
            offset: None,
            trace_id: this.trace_id.clone(),
            missing_value_interpretations: HashMap::new(),
            default_missing_value_interpretation: MissingValueInterpretation::Unspecified.into(),
            rows: Some(append_rows_request::Rows::ProtoRows(proto_data)),
        };

        // Track the total bytes being sent using encoded_len
        let request_bytes = append_rows_request.encoded_len();
        this.bytes_sent_counter.fetch_add(request_bytes, Ordering::Relaxed);

        *this.current_index += processed_count;
        // After the first request, avoid sending schema again in this stream
        if *this.include_schema_next {
            *this.include_schema_next = false;
        }

        Poll::Ready(Some(append_rows_request))
    }
}

/// High-level client for BigQuery Storage Write API operations.
#[derive(Clone)]
pub struct StorageApi {
    /// Connection pool for gRPC clients to BigQuery Storage Write API.
    connection_pool: ConnectionPool,
    /// Authentication provider for API requests.
    auth: Arc<dyn Authenticator>,
    /// Base URL for BigQuery API endpoints.
    base_url: String,
}

impl StorageApi {
    /// Creates a new storage API client instance.
    pub(crate) async fn new(auth: Arc<dyn Authenticator>) -> Result<Self, BQError> {
        let connection_pool = ConnectionPool::new().await?;

        Ok(Self {
            connection_pool,
            auth,
            base_url: BIG_QUERY_V2_URL.to_string(),
        })
    }

    /// Configures a custom base URL for BigQuery API endpoints.
    ///
    /// Primarily used for testing scenarios with mock or alternative
    /// BigQuery API endpoints. Returns a mutable reference for chaining.
    pub(crate) fn with_base_url(&mut self, base_url: String) -> &mut Self {
        self.base_url = base_url;
        self
    }

    /// Encodes message rows into protobuf format with size management.
    ///
    /// Processes as many rows as possible while respecting the specified
    /// size limit. Returns the encoded protobuf data and the count of
    /// rows successfully processed. When the returned count is less than
    /// the input slice length, additional calls are required for remaining rows.
    ///
    /// The size limit should be below 10MB to accommodate request metadata
    /// overhead; 9MB provides a safe margin.
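    ///
    /// A minimal chunking sketch (not compiled as a doctest; `table_descriptor` and
    /// `rows` are assumed to exist in the caller's scope and `MyRow` is a placeholder
    /// `prost::Message` type):
    ///
    /// ```ignore
    /// let mut remaining: &[MyRow] = &rows;
    /// loop {
    ///     let (encoded, n) = StorageApi::create_rows(&table_descriptor, remaining, 9 * 1024 * 1024);
    ///     // Send `encoded` via `append_rows` here, then continue with the unprocessed tail.
    ///     if n == remaining.len() {
    ///         break;
    ///     }
    ///     remaining = &remaining[n..];
    /// }
    /// ```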
    pub fn create_rows<M: Message>(
        table_descriptor: &TableDescriptor,
        rows: &[M],
        max_size_bytes: usize,
    ) -> (append_rows_request::Rows, usize) {
        let proto_schema = Self::create_proto_schema(table_descriptor);

        let mut serialized_rows = Vec::new();
        let mut total_size = 0;

        for row in rows {
            // Use encoded_len to avoid encoding a row that won't fit.
            let row_size = row.encoded_len();
            if total_size + row_size > max_size_bytes {
                break;
            }

            let encoded_row = row.encode_to_vec();
            debug_assert_eq!(
                encoded_row.len(),
                row_size,
                "prost::encoded_len disagrees with encode_to_vec length"
            );

            serialized_rows.push(encoded_row);
            total_size += row_size;
        }

        let num_rows_processed = serialized_rows.len();

        let proto_rows = ProtoRows { serialized_rows };

        let proto_data = ProtoData {
            writer_schema: Some(proto_schema),
            rows: Some(proto_rows),
        };

        (append_rows_request::Rows::ProtoRows(proto_data), num_rows_processed)
    }

    /// Retrieves metadata for a BigQuery write stream.
    ///
    /// Fetches stream information including schema definition and state
    /// details according to the specified view level. Higher view levels
    /// provide more comprehensive information but may have higher latency.
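    ///
    /// A minimal usage sketch (not compiled as a doctest; `client` and `stream_name`
    /// are assumed to exist in the caller's scope, and the printed `name` field is an
    /// assumption about the generated `WriteStream` type):
    ///
    /// ```ignore
    /// let write_stream = client
    ///     .storage_mut()
    ///     .get_write_stream(&stream_name, WriteStreamView::Full)
    ///     .await?;
    /// println!("write stream: {:?}", write_stream.name);
    /// ```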
    pub async fn get_write_stream(
        &mut self,
        stream_name: &StreamName,
        view: WriteStreamView,
    ) -> Result<WriteStream, BQError> {
        let get_write_stream_request = GetWriteStreamRequest {
            name: stream_name.to_string(),
            view: view.into(),
        };

        let request = Self::new_authorized_request(self.auth.clone(), get_write_stream_request).await?;
        let mut client = self.connection_pool.get_client().await?;
        let response = client.get_write_stream(request).await?;
        let write_stream = response.into_inner();

        Ok(write_stream)
    }

    /// Appends rows to a BigQuery table using the Storage Write API.
    ///
    /// Transmits the provided rows to the specified stream and returns
    /// a streaming response for processing results. The trace ID enables
    /// request tracking across distributed systems for debugging.
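    ///
    /// A minimal usage sketch (not compiled as a doctest; mirrors the pattern used in
    /// this module's tests, with `client`, `table_descriptor`, `stream_name`, and `rows`
    /// assumed to exist in the caller's scope):
    ///
    /// ```ignore
    /// let (encoded_rows, _) = StorageApi::create_rows(&table_descriptor, &rows, 9 * 1024 * 1024);
    /// let mut streaming = client
    ///     .storage_mut()
    ///     .append_rows(&stream_name, encoded_rows, "my_trace_id".to_string())
    ///     .await?;
    /// while let Some(response) = streaming.next().await {
    ///     response?;
    /// }
    /// ```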
    pub async fn append_rows(
        &mut self,
        stream_name: &StreamName,
        rows: append_rows_request::Rows,
        trace_id: String,
    ) -> Result<Streaming<AppendRowsResponse>, BQError> {
        let append_rows_request = AppendRowsRequest {
            write_stream: stream_name.to_string(),
            offset: None,
            trace_id,
            missing_value_interpretations: HashMap::new(),
            default_missing_value_interpretation: MissingValueInterpretation::Unspecified.into(),
            rows: Some(rows),
        };

        let request =
            Self::new_authorized_request(self.auth.clone(), tokio_stream::iter(vec![append_rows_request])).await?;
        let mut client = self.connection_pool.get_client().await?;
        let response = client.append_rows(request).await?;
        let streaming = response.into_inner();

        Ok(streaming)
    }

    /// Appends rows from multiple table batches with concurrent processing.
    ///
    /// Returns a collection of batch results containing responses, metadata,
    /// and bytes sent for each batch processed. Results are ordered by
    /// completion, not by submission; use `BatchAppendResult::batch_index`
    /// to correlate with the original input order.
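    ///
    /// A minimal usage sketch (not compiled as a doctest; mirrors this module's tests,
    /// with `client`, `stream_name`, `table_descriptor`, and `rows` assumed to exist in
    /// the caller's scope):
    ///
    /// ```ignore
    /// let batches = vec![TableBatch::new(stream_name, table_descriptor, rows)];
    /// let results = client
    ///     .storage_mut()
    ///     .append_table_batches_concurrent(batches, 4, "my_trace_id")
    ///     .await?;
    /// assert!(results.iter().all(|r| r.is_success()));
    /// ```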
    pub async fn append_table_batches_concurrent<M>(
        &self,
        table_batches: Vec<TableBatch<M>>,
        max_concurrent_streams: usize,
        trace_id: &str,
    ) -> Result<Vec<BatchAppendResult>, BQError>
    where
        M: Message + Send + 'static,
    {
        if table_batches.is_empty() {
            return Ok(Vec::new());
        }

        let batches_num = table_batches.len();
        let semaphore = Arc::new(Semaphore::new(max_concurrent_streams));

        let mut join_set = JoinSet::new();
        for (idx, table_batch) in table_batches.into_iter().enumerate() {
            // Acquire a concurrency slot and hold it until responses are fully drained.
            let permit = semaphore.clone().acquire_owned().await?;

            let stream_name = table_batch.stream_name.clone();
            let table_descriptor = table_batch.table_descriptor;
            let rows = table_batch.rows;
            let trace_id = trace_id.to_string();
            let client = self.clone();

            join_set.spawn(async move {
                // We compute the proto schema once for the entire batch. We might want to compute it
                // once per stream but for now this is fine.
                let proto_schema = Self::create_proto_schema(&table_descriptor);

                // Create an atomic counter for tracking bytes sent for this batch.
                let bytes_sent_counter = Arc::new(AtomicUsize::new(0));

                // Build the request stream which will split the request into multiple requests if
                // necessary.
                let request_stream =
                    AppendRequestsStream::new(rows, proto_schema, stream_name, trace_id, bytes_sent_counter.clone());

                let mut batch_responses = Vec::new();

                // Make the request for append rows and poll the response stream until exhausted. Since
                // we might send multiple requests over the same stream, we will also receive multiple
                // responses.
                match Self::new_authorized_request(client.auth.clone(), request_stream).await {
                    Ok(request) => match client.connection_pool.get_client().await {
                        Ok(write_client) => {
                            // We clone the client to cheaply issue multiple requests over the same connection
                            // without holding onto the original connection object.
                            //
                            // This approach utilizes connections efficiently, assuming each call to the pool
                            // returns a different connection — which holds true when using `Fifo` queuing.
                            let mut client = write_client.deref().clone();

                            // We return the connection to the pool immediately so that it can be reused
                            // by another task.
                            drop(write_client);

                            match client.append_rows(request).await {
                                Ok(response) => {
                                    let mut streaming_response = response.into_inner();
                                    while let Some(response) = streaming_response.next().await {
                                        batch_responses.push(response);
                                    }
                                }
                                Err(status) => {
                                    batch_responses.push(Err(status));
                                }
                            }
                        }
                        Err(pool_err) => {
                            batch_responses.push(Err(Status::unknown(format!("Pool error: {}", pool_err))));
                        }
                    },
                    Err(err) => {
                        batch_responses.push(Err(Status::unknown(err.to_string())));
                    }
                }

                // Free the concurrency slot only after fully draining responses or after error.
                drop(permit);

                // We load the atomic directly in the result to avoid exposing atomics. By doing this
                // we assume that when this code path is reached, the stream has been fully drained.
                BatchAppendResult::new(idx, batch_responses, bytes_sent_counter.load(Ordering::Relaxed))
            });
        }

        // Collect all task results in the order of completion.
        let mut batch_results = Vec::with_capacity(batches_num);
        while let Some(batch_result) = join_set.join_next().await {
            let batch_result = batch_result?;
            batch_results.push(batch_result);
        }

        Ok(batch_results)
    }

    /// Creates an authenticated gRPC request with Bearer token authorization.
    ///
    /// Retrieves an access token from the authenticator and attaches it
    /// as a Bearer token in the request authorization header. Used by
    /// all Storage Write API operations requiring authentication.
    async fn new_authorized_request<T>(auth: Arc<dyn Authenticator>, message: T) -> Result<Request<T>, BQError> {
        let access_token = auth.access_token().await?;
        let bearer_token = format!("Bearer {access_token}");
        let bearer_value = bearer_token.as_str().try_into()?;

        let mut request = Request::new(message);
        let meta = request.metadata_mut();
        meta.insert("authorization", bearer_value);

        Ok(request)
    }

    /// Converts table field descriptors to protobuf field descriptors.
    ///
    /// Transforms the high-level field descriptors into the protobuf
    /// format required by BigQuery Storage Write API schema definitions.
    /// Maps column types and modes to their protobuf equivalents.
    fn create_field_descriptors(table_descriptor: &TableDescriptor) -> Vec<FieldDescriptorProto> {
        table_descriptor
            .field_descriptors
            .iter()
            .map(|fd| {
                let typ: Type = fd.typ.into();
                let label: Label = fd.mode.into();

                FieldDescriptorProto {
                    name: Some(fd.name.clone()),
                    number: Some(fd.number as i32),
                    label: Some(label.into()),
                    r#type: Some(typ.into()),
                    type_name: None,
                    extendee: None,
                    default_value: None,
                    oneof_index: None,
                    json_name: None,
                    options: None,
                    proto3_optional: None,
                }
            })
            .collect()
    }

    /// Creates a protobuf descriptor from field descriptors.
    ///
    /// Wraps field descriptors in a [`DescriptorProto`] structure with
    /// the standard table schema name. Used as an intermediate step
    /// in protobuf schema generation.
    fn create_proto_descriptor(field_descriptors: Vec<FieldDescriptorProto>) -> DescriptorProto {
        DescriptorProto {
            name: Some("table_schema".to_string()),
            field: field_descriptors,
            extension: vec![],
            nested_type: vec![],
            enum_type: vec![],
            extension_range: vec![],
            oneof_decl: vec![],
            options: None,
            reserved_range: vec![],
            reserved_name: vec![],
        }
    }

    /// Generates a complete protobuf schema from table descriptor.
    ///
    /// Creates the final [`ProtoSchema`] structure containing all
    /// field definitions required for BigQuery Storage Write API
    /// operations. This schema is included in append requests.
    fn create_proto_schema(table_descriptor: &TableDescriptor) -> ProtoSchema {
        let field_descriptors = Self::create_field_descriptors(table_descriptor);
        let proto_descriptor = Self::create_proto_descriptor(field_descriptors);

        ProtoSchema {
            proto_descriptor: Some(proto_descriptor),
        }
    }
}

#[cfg(test)]
pub mod test {
    use prost::Message;
    use std::sync::Arc;
    use std::time::{Duration, SystemTime};
    use tokio_stream::StreamExt;

    use crate::model::dataset::Dataset;
    use crate::model::field_type::FieldType;
    use crate::model::table::Table;
    use crate::model::table_field_schema::TableFieldSchema;
    use crate::model::table_schema::TableSchema;
    use crate::storage::{
        ColumnMode, ColumnType, ConnectionPool, FieldDescriptor, StorageApi, StreamName, TableBatch, TableDescriptor,
    };
    use crate::{env_vars, Client};

    #[derive(Clone, PartialEq, Message)]
    struct Actor {
        #[prost(int32, tag = "1")]
        actor_id: i32,
        #[prost(string, tag = "2")]
        first_name: String,
        #[prost(string, tag = "3")]
        last_name: String,
        #[prost(string, tag = "4")]
        last_update: String,
    }

    fn create_test_table_descriptor() -> Arc<TableDescriptor> {
        let field_descriptors = vec![
            FieldDescriptor {
                name: "actor_id".to_string(),
                number: 1,
                typ: ColumnType::Int64,
                mode: ColumnMode::Nullable,
            },
            FieldDescriptor {
                name: "first_name".to_string(),
                number: 2,
                typ: ColumnType::String,
                mode: ColumnMode::Nullable,
            },
            FieldDescriptor {
                name: "last_name".to_string(),
                number: 3,
                typ: ColumnType::String,
                mode: ColumnMode::Nullable,
            },
            FieldDescriptor {
                name: "last_update".to_string(),
                number: 4,
                typ: ColumnType::String,
                mode: ColumnMode::Nullable,
            },
        ];

        Arc::new(TableDescriptor { field_descriptors })
    }

    async fn setup_test_table(
        client: &mut Client,
        project_id: &str,
        dataset_id: &str,
        table_id: &str,
    ) -> Result<(), Box<dyn std::error::Error>> {
        client.dataset().delete_if_exists(project_id, dataset_id, true).await;

        let created_dataset = client.dataset().create(Dataset::new(project_id, dataset_id)).await?;
        assert_eq!(created_dataset.id, Some(format!("{project_id}:{dataset_id}")));

        let table = Table::new(
            project_id,
            dataset_id,
            table_id,
            TableSchema::new(vec![
                TableFieldSchema::new("actor_id", FieldType::Int64),
                TableFieldSchema::new("first_name", FieldType::String),
                TableFieldSchema::new("last_name", FieldType::String),
                TableFieldSchema::new("last_update", FieldType::Timestamp),
            ]),
        );
        let created_table = client
            .table()
            .create(
                table
                    .description("A table used for unit tests")
                    .label("owner", "me")
                    .label("env", "prod")
                    .expiration_time(SystemTime::now() + Duration::from_secs(3600)),
            )
            .await?;
        assert_eq!(created_table.table_reference.table_id, table_id.to_string());

        Ok(())
    }

    fn create_test_actor(id: i32, first_name: &str) -> Actor {
        Actor {
            actor_id: id,
            first_name: first_name.to_string(),
            last_name: "Doe".to_string(),
            last_update: "2007-02-15 09:34:33 UTC".to_string(),
        }
    }

    async fn call_append_rows(
        client: &mut Client,
        table_descriptor: &TableDescriptor,
        stream_name: &StreamName,
        trace_id: String,
        mut rows: &[Actor],
        max_size: usize,
    ) -> Result<u8, Box<dyn std::error::Error>> {
        // This loop is needed because the AppendRows API has a 10MB payload limit, so `create_rows`
        // may not process every row in the slice in a single call. The two rows sent in this test fit
        // comfortably within the limit, but in a real-world scenario with more rows the loop is
        // required to process them all.
        let mut num_append_rows_calls = 0;
        loop {
            let (encoded_rows, num_processed) = StorageApi::create_rows(table_descriptor, rows, max_size);
            let mut streaming = client
                .storage_mut()
                .append_rows(stream_name, encoded_rows, trace_id.clone())
                .await?;

            num_append_rows_calls += 1;

            while let Some(response) = streaming.next().await {
                response?;
            }

            // All the rows have been processed
            if num_processed == rows.len() {
                break;
            }

            // Process the remaining rows
            rows = &rows[num_processed..];
        }

        Ok(num_append_rows_calls)
    }

    #[tokio::test]
    async fn test_connection_pool() {
        let connection_pool = ConnectionPool::new().await.unwrap();

        // Test that we can get multiple connections from the pool
        let client1 = connection_pool.get_client().await.unwrap();
        let client2 = connection_pool.get_client().await.unwrap();

        // Connections should be different instances but both valid
        // Just verify they exist and are usable (we can't easily test the same connection
        // is reused without dropping and re-acquiring)
        assert!(std::ptr::addr_of!(*client1) != std::ptr::addr_of!(*client2));

        // Drop connections to return them to the pool
        drop(client1);
        drop(client2);

        // Get another connection to verify pool recycling works
        let client3 = connection_pool.get_client().await.unwrap();
        drop(client3);
    }

    #[tokio::test]
    async fn test_append_rows() {
        let (ref project_id, ref dataset_id, ref table_id, ref sa_key) = env_vars();
        let dataset_id = &format!("{dataset_id}_storage");

        let mut client = Client::from_service_account_key_file(sa_key).await.unwrap();

        setup_test_table(&mut client, project_id, dataset_id, table_id)
            .await
            .unwrap();

        let table_descriptor = create_test_table_descriptor();
        let actor1 = create_test_actor(1, "John");
        let actor2 = create_test_actor(2, "Jane");

        let stream_name = StreamName::new_default(project_id.clone(), dataset_id.clone(), table_id.clone());
        let trace_id = "test_client".to_string();

        let rows: &[Actor] = &[actor1, actor2];

        let max_size = 9 * 1024 * 1024; // 9 MB
        let num_append_rows_calls = call_append_rows(
            &mut client,
            &table_descriptor,
            &stream_name,
            trace_id.clone(),
            rows,
            max_size,
        )
        .await
        .unwrap();
        assert_eq!(num_append_rows_calls, 1);

        // It was found after experimenting that one row in this test encodes to about 38 bytes
        // We artificially limit the size of the rows to test that the loop processes all the rows
        let max_size = 50; // 50 bytes
        let num_append_rows_calls =
            call_append_rows(&mut client, &table_descriptor, &stream_name, trace_id, rows, max_size)
                .await
                .unwrap();
        assert_eq!(num_append_rows_calls, 2);
    }

    #[tokio::test]
    async fn test_append_table_batches_concurrent() {
        let (ref project_id, ref dataset_id, ref table_id, ref sa_key) = env_vars();
        let dataset_id = &format!("{dataset_id}_storage_table_batches");

        let mut client = Client::from_service_account_key_file(sa_key).await.unwrap();

        setup_test_table(&mut client, project_id, dataset_id, table_id)
            .await
            .unwrap();

        let table_descriptor = create_test_table_descriptor();
        let stream_name = StreamName::new_default(project_id.clone(), dataset_id.clone(), table_id.clone());
        let trace_id = "test_table_batches";

        // Create multiple table batches (all targeting the same table in this test)
        let batch1 = TableBatch::new(
            stream_name.clone(),
            table_descriptor.clone(),
            vec![
                create_test_actor(1, "John"),
                create_test_actor(2, "Jane"),
                create_test_actor(3, "Bob"),
                create_test_actor(4, "Alice"),
            ],
        );

        let batch2 = TableBatch::new(
            stream_name.clone(),
            table_descriptor.clone(),
            vec![create_test_actor(5, "Charlie"), create_test_actor(6, "Dave")],
        );

        let batch3 = TableBatch::new(stream_name, table_descriptor, vec![create_test_actor(7, "Eve")]);

        let table_batches = vec![batch1, batch2, batch3];

        // Use a concurrency limit of 2 to assert that all batches are processed even though
        // more batches are supplied than the limit allows to run at once.
        let batch_responses = client
            .storage_mut()
            .append_table_batches_concurrent(table_batches, 2, trace_id)
            .await
            .unwrap();

        // We expect one result per batch, 3 in total.
        assert_eq!(batch_responses.len(), 3);

        // Verify all responses are successful and track total bytes sent.
        let mut total_bytes_across_all_batches = 0;
        for batch_result in batch_responses {
            // Verify the batch was processed successfully.
            assert!(
                batch_result.is_success(),
                "Batch {} should be successful.",
                batch_result.batch_index,
            );

            // Verify each individual response for detailed error reporting.
            for response in &batch_result.responses {
                assert!(response.is_ok(), "Response should be successful: {:?}", response);
            }

            // Verify that some bytes were sent (should be greater than 0).
            let bytes_sent = batch_result.bytes_sent;
            assert!(
                bytes_sent > 0,
                "Bytes sent should be greater than 0 for batch {}, got: {}",
                batch_result.batch_index,
                bytes_sent
            );

            total_bytes_across_all_batches += bytes_sent;
        }

        // Verify that we sent bytes across all batches
        assert!(
            total_bytes_across_all_batches > 0,
            "Total bytes sent across all batches should be greater than 0"
        );
    }
}