use hyper::client::connect::Connect;
use yup_oauth2::authenticator::Authenticator;

use prost_types::Timestamp;
use tonic::metadata::MetadataValue;
use tonic::transport::{Channel, ClientTlsConfig};
use tonic::{Request, Streaming};

use crate::googleapis::big_query_read_client::BigQueryReadClient;
use crate::googleapis::{
    read_session::{TableModifiers, TableReadOptions},
    CreateReadSessionRequest, DataFormat, ReadRowsRequest, ReadRowsResponse,
    ReadSession as BigQueryReadSession, ReadStream,
};
use crate::Error;
use crate::RowsStreamReader;

static API_ENDPOINT: &str = "https://bigquerystorage.googleapis.com";
static API_DOMAIN: &str = "bigquerystorage.googleapis.com";
static API_SCOPE: &str = "https://www.googleapis.com/auth/bigquery";

pub struct Table {
    project_id: String,
    dataset_id: String,
    table_id: String,
}

impl Table {
    pub fn new(project_id: &str, dataset_id: &str, table_id: &str) -> Self {
        Self {
            project_id: project_id.to_string(),
            dataset_id: dataset_id.to_string(),
            table_id: table_id.to_string(),
        }
    }
}

impl std::fmt::Display for Table {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "projects/{}/datasets/{}/tables/{}",
            self.project_id, self.dataset_id, self.table_id
        )
    }
}

macro_rules! read_session_builder {
    {
        $(
            $(#[$m:meta])*
            $field:ident: $ty:path,
        )*
    } => {
        #[derive(Default)]
        struct ReadSessionBuilderOpts {
            $(
                $field: Option<$ty>,
            )*
        }

        pub struct ReadSessionBuilder<'a, T> {
            client: &'a mut Client<T>,
            table: Table,
            opts: ReadSessionBuilderOpts,
        }

        impl<'a, T> ReadSessionBuilder<'a, T> {
            fn new(client: &'a mut Client<T>, table: Table) -> Self {
                let opts = ReadSessionBuilderOpts::default();
                Self { client, table, opts }
            }

            $(
                $(#[$m])*
                pub fn $field(mut self, $field: $ty) -> Self {
                    self.opts.$field = Some($field);
                    self
                }
            )*
        }
    };
}

read_session_builder! {
    #[doc = "Sets the data format of the output data. Defaults to Arrow if not set."]
    data_format: DataFormat,
    #[doc = "Sets the snapshot time of the table. If not set, interpreted as now."]
    snapshot_time: Timestamp,
    #[doc = "Names of the fields in the table that should be read. If empty or not set, all fields will be read. If a specified field is a nested field, all of its sub-fields will be selected. The output field order is unrelated to the order of fields in `selected_fields`."]
    selected_fields: Vec<String>,
    #[doc = "SQL text filtering statement, similar to a `WHERE` clause in a query. Aggregates are not supported."]
    #[doc = ""]
    #[doc = "Examples:"]
    #[doc = "- `int_field > 5`"]
    #[doc = "- `date_field = CAST('2014-9-27' as DATE)`"]
    #[doc = "- `nullable_field is not NULL`"]
    #[doc = "- `st_equals(geo_field, st_geofromtext(\"POINT(2, 2)\"))`"]
    #[doc = "- `numeric_field BETWEEN 1.0 AND 5.0`"]
    row_restriction: String,
    #[doc = "Maximum initial number of streams. If unset or zero, the server will choose a number of streams that produces reasonable throughput. Must be non-negative. The number of streams may be lower than the requested number, depending on the amount of parallelism that is reasonable for the table. An error is returned if the requested maximum exceeds the current system limit of 1,000."]
    max_stream_count: i32,
    #[doc = "The request project that owns the session. If not set, defaults to the project owning the table to be read."]
    parent_project_id: String,
}
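
// Illustrative sketch only (not additional API surface): the macro invocation
// above generates one chainable setter per option, so a caller might configure
// a session roughly like this, with placeholder project/dataset/table names:
//
//     let read_session = client
//         .read_session_builder(Table::new("my-project", "my_dataset", "my_table"))
//         .data_format(DataFormat::Arrow)
//         .row_restriction("int_field > 5".to_string())
//         .max_stream_count(4)
//         .build()
//         .await?;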

impl<'a, C> ReadSessionBuilder<'a, C>
where
    C: Connect + Clone + Send + Sync + 'static,
{
    pub async fn build(self) -> Result<ReadSession<'a, C>, Error> {
        let table = self.table.to_string();

        let mut inner = BigQueryReadSession {
            table,
            ..Default::default()
        };

        let data_format = self.opts.data_format.unwrap_or(DataFormat::Arrow);
        inner.set_data_format(data_format);

        if let Some(snapshot_time) = self.opts.snapshot_time {
            inner.table_modifiers = Some(TableModifiers {
                snapshot_time: Some(snapshot_time),
            });
        }

        let mut tro = TableReadOptions::default();
        if let Some(selected_fields) = self.opts.selected_fields {
            tro.selected_fields = selected_fields;
        }

        if let Some(row_restriction) = self.opts.row_restriction {
            tro.row_restriction = row_restriction;
        }

        // Attach the read options to the session so that field selection and
        // the row restriction actually take effect.
        inner.read_options = Some(tro);

        let parent_project_id = self.opts.parent_project_id.unwrap_or(self.table.project_id);
        let parent = format!("projects/{}", parent_project_id);
        let max_stream_count = self.opts.max_stream_count.unwrap_or_default();

        let req = CreateReadSessionRequest {
            parent,
            read_session: Some(inner),
            max_stream_count,
        };

        let inner = self.client.create_read_session(req).await?;

        Ok(ReadSession {
            client: self.client,
            inner,
        })
    }
}

pub struct ReadSession<'a, C> {
    client: &'a mut Client<C>,
    inner: BigQueryReadSession,
}

impl<'a, C> ReadSession<'a, C>
where
    C: Connect + Clone + Send + Sync + 'static,
{
    pub async fn next_stream(&mut self) -> Result<Option<RowsStreamReader>, Error> {
        match self.inner.streams.pop() {
            Some(ReadStream { name }) => {
                let rows_stream = self.client.read_stream_rows(&name).await?;
                let schema = self
                    .inner
                    .schema
                    .clone()
                    .ok_or(Error::invalid("empty schema response"))?;
                Ok(Some(RowsStreamReader::new(schema, rows_stream)))
            }
            None => Ok(None),
        }
    }
}
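
// Illustrative sketch only: a session is typically drained by looping until
// `next_stream` returns `None`, mirroring the test at the bottom of this file.
//
//     while let Some(reader) = read_session.next_stream().await? {
//         let mut batches = reader.into_arrow_reader().await?;
//         while let Some(batch) = batches.next() {
//             // handle the Arrow record batch in `batch?` ...
//         }
//     }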

pub struct Client<C> {
    auth: Authenticator<C>,
    big_query_read_client: BigQueryReadClient<Channel>,
}

impl<C> Client<C>
where
    C: Connect + Clone + Send + Sync + 'static,
{
    pub async fn new(auth: Authenticator<C>) -> Result<Self, Error> {
        let tls_config = ClientTlsConfig::new().domain_name(API_DOMAIN);
        let channel = Channel::from_static(API_ENDPOINT)
            .tls_config(tls_config)?
            .connect()
            .await?;
        let big_query_read_client = BigQueryReadClient::new(channel);
        Ok(Self {
            auth,
            big_query_read_client,
        })
    }

    pub fn read_session_builder(&mut self, table: Table) -> ReadSessionBuilder<'_, C> {
        ReadSessionBuilder::new(self, table)
    }

    async fn new_request<D>(&self, t: D, params: &str) -> Result<Request<D>, Error> {
        let token = self.auth.token(&[API_SCOPE]).await?;
        let bearer_token = format!("Bearer {}", token.as_str());
        let bearer_value = MetadataValue::from_str(&bearer_token)?;
        let mut req = Request::new(t);
        let meta = req.metadata_mut();
        meta.insert("authorization", bearer_value);
        meta.insert("x-goog-request-params", MetadataValue::from_str(params)?);
        Ok(req)
    }

    async fn create_read_session(
        &mut self,
        req: CreateReadSessionRequest,
    ) -> Result<BigQueryReadSession, Error> {
        let table_uri = &req.read_session.as_ref().unwrap().table;
        let params = format!("read_session.table={}", table_uri);
        let wrapped = self.new_request(req, &params).await?;

        let read_session = self
            .big_query_read_client
            .create_read_session(wrapped)
            .await?
            .into_inner();
        Ok(read_session)
    }

    async fn read_stream_rows(
        &mut self,
        stream: &str,
    ) -> Result<Streaming<ReadRowsResponse>, Error> {
        let req = ReadRowsRequest {
            read_stream: stream.to_string(),
            offset: 0,
        };
        let params = format!("read_stream={}", req.read_stream);
        let wrapped = self.new_request(req, &params).await?;
        let read_rows_response = self
            .big_query_read_client
            .read_rows(wrapped)
            .await?
            .into_inner();
        Ok(read_rows_response)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn read_a_table_with_arrow() {
        let sa_key = yup_oauth2::read_service_account_key("clientsecret.json")
            .await
            .unwrap();
        let auth = yup_oauth2::ServiceAccountAuthenticator::builder(sa_key)
            .build()
            .await
            .unwrap();

        let mut client = Client::new(auth).await.unwrap();

        let test_table = Table::new("bigquery-public-data", "london_bicycles", "cycle_stations");

        let mut read_session = client
            .read_session_builder(test_table)
            .parent_project_id("openquery-public-testing".to_string())
            .build()
            .await
            .unwrap();

        let mut num_rows = 0;

        while let Some(stream_reader) = read_session.next_stream().await.unwrap() {
            let mut arrow_stream_reader = stream_reader.into_arrow_reader().await.unwrap();
            while let Some(record_batch) = arrow_stream_reader.next() {
                num_rows += record_batch.unwrap().num_rows();
            }
        }

        assert_eq!(num_rows, 789);
    }
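
    // A minimal, self-contained check of the `Display` impl on `Table`; the
    // project/dataset/table names here are placeholders, not real resources.
    #[test]
    fn table_display_formats_fully_qualified_path() {
        let table = Table::new("my-project", "my_dataset", "my_table");
        assert_eq!(
            table.to_string(),
            "projects/my-project/datasets/my_dataset/tables/my_table"
        );
    }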
}