Skip to main content

gcloud_bigquery/storage_write/stream/mod.rs

1use crate::grpc::apiv1::conn_pool::ConnectionManager;
2use crate::storage_write::flow::FlowController;
3use crate::storage_write::AppendRowsRequestBuilder;
4use google_cloud_gax::grpc::{IntoStreamingRequest, Status, Streaming};
5use google_cloud_googleapis::cloud::bigquery::storage::v1::{
6    AppendRowsRequest, AppendRowsResponse, FinalizeWriteStreamRequest, WriteStream,
7};
8use std::sync::Arc;
9
10pub mod buffered;
11pub mod committed;
12pub mod default;
13pub mod pending;
14
15pub struct Stream {
16    inner: WriteStream,
17    cons: Arc<ConnectionManager>,
18    fc: Option<FlowController>,
19}
20
21impl Stream {
22    pub(crate) fn new(inner: WriteStream, cons: Arc<ConnectionManager>, max_insert_count: usize) -> Self {
23        Self {
24            inner,
25            cons,
26            fc: if max_insert_count > 0 {
27                Some(FlowController::new(max_insert_count))
28            } else {
29                None
30            },
31        }
32    }
33}
34
35pub trait AsStream: Sized {
36    fn as_ref(&self) -> &Stream;
37
38    fn name(&self) -> &str {
39        &self.as_ref().inner.name
40    }
41
42    fn create_streaming_request(
43        &self,
44        rows: Vec<AppendRowsRequestBuilder>,
45    ) -> impl google_cloud_gax::grpc::codegen::tokio_stream::Stream<Item = AppendRowsRequest> {
46        let name = self.name().to_string();
47        async_stream::stream! {
48            for row in rows {
49                yield row.build(&name);
50            }
51        }
52    }
53}
54
55pub(crate) struct ManagedStreamDelegate {}
56
57impl ManagedStreamDelegate {
58    async fn append_rows(
59        stream: &Stream,
60        rows: Vec<AppendRowsRequestBuilder>,
61    ) -> Result<Streaming<AppendRowsResponse>, Status> {
62        let name = stream.inner.name.to_string();
63        let req = async_stream::stream! {
64            for row in rows {
65                yield row.build(&name);
66            }
67        };
68        Self::append_streaming_request(stream, req).await
69    }
70
71    async fn append_streaming_request(
72        stream: &Stream,
73        req: impl IntoStreamingRequest<Message = AppendRowsRequest>,
74    ) -> Result<Streaming<AppendRowsResponse>, Status> {
75        match &stream.fc {
76            None => {
77                let mut client = stream.cons.writer();
78                Ok(client.append_rows(req).await?.into_inner())
79            }
80            Some(fc) => {
81                let permit = fc.acquire().await;
82                let mut client = stream.cons.writer();
83                let result = client.append_rows(req).await?.into_inner();
84                drop(permit);
85                Ok(result)
86            }
87        }
88    }
89}
90
91pub(crate) struct DisposableStreamDelegate {}
92impl DisposableStreamDelegate {
93    async fn finalize(stream: &Stream) -> Result<i64, Status> {
94        let res = stream
95            .cons
96            .writer()
97            .finalize_write_stream(
98                FinalizeWriteStreamRequest {
99                    name: stream.inner.name.to_string(),
100                },
101                None,
102            )
103            .await?
104            .into_inner();
105        Ok(res.row_count)
106    }
107}
108
/// Test fixtures, crate-visible so other test modules can reuse them.
#[cfg(test)]
pub(crate) mod tests {
    use std::sync::Arc;

    use arrow::array::StringArray;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use prost_types::{field_descriptor_proto, DescriptorProto, FieldDescriptorProto};

    use crate::storage_write::AppendRowsRequestBuilder;

    /// Single-column protobuf message used as a row payload in tests.
    #[derive(Clone, PartialEq, ::prost::Message)]
    pub(crate) struct TestData {
        #[prost(string, tag = "1")]
        pub col_string: String,
    }

    /// Installs a `tracing_subscriber` fmt subscriber whose filter comes from the
    /// environment, with `google_cloud_bigquery=trace` added on top.
    ///
    /// The result of `try_init` is ignored, presumably because another test may
    /// already have installed a global subscriber.
    pub(crate) fn init() {
        let env_filter = tracing_subscriber::filter::EnvFilter::from_default_env()
            .add_directive("google_cloud_bigquery=trace".parse().unwrap());
        let _ = tracing_subscriber::fmt()
            .with_env_filter(env_filter)
            .try_init();
    }

    /// Builds an [`AppendRowsRequestBuilder`] for the `TestData` schema (a single
    /// string field `col_string`, number 1) over the pre-encoded rows in `buf`.
    pub(crate) fn create_append_rows_request(buf: Vec<Vec<u8>>) -> AppendRowsRequestBuilder {
        // Only the populated descriptor fields are listed; all others keep their
        // prost defaults (`None` / empty `Vec`).
        let col_string = FieldDescriptorProto {
            name: Some("col_string".to_string()),
            number: Some(1),
            r#type: Some(field_descriptor_proto::Type::String.into()),
            ..Default::default()
        };
        let descriptor = DescriptorProto {
            name: Some("TestData".to_string()),
            field: vec![col_string],
            ..Default::default()
        };
        AppendRowsRequestBuilder::new(descriptor, buf)
    }

    /// Builds an [`AppendRowsRequestBuilder`] from a one-column Arrow record batch
    /// whose non-nullable `Utf8` column `col_string` holds `values`.
    pub(crate) fn create_arrow_append_rows_request(values: Vec<String>) -> AppendRowsRequestBuilder {
        let column_field = Field::new("col_string", DataType::Utf8, false);
        let schema = Arc::new(Schema::new(vec![column_field]));
        let column = Arc::new(StringArray::from(values));
        let batch = RecordBatch::try_new(schema, vec![column]).unwrap();
        AppendRowsRequestBuilder::from_record_batch(&batch).unwrap()
    }
}