bigbytes_driver/
rest_api.rs

1// Copyright 2024 Digitrans Inc
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, VecDeque};
16use std::future::Future;
17use std::marker::PhantomData;
18use std::path::Path;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22
23use async_trait::async_trait;
24use log::info;
25use tokio::fs::File;
26use tokio::io::BufReader;
27use tokio_stream::Stream;
28
29use databend_client::PresignedResponse;
30use databend_client::QueryResponse;
31use databend_client::{APIClient, SchemaField};
32use bigbytes_driver_core::error::{Error, Result};
33use bigbytes_driver_core::raw_rows::{RawRow, RawRowIterator, RawRowWithStats};
34use bigbytes_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
35use bigbytes_driver_core::schema::{Schema, SchemaRef};
36
37use crate::conn::{Connection, ConnectionInfo, Reader};
38
39#[derive(Clone)]
40pub struct RestAPIConnection {
41    client: Arc<APIClient>,
42}
43
44#[async_trait]
45impl Connection for RestAPIConnection {
46    async fn info(&self) -> ConnectionInfo {
47        ConnectionInfo {
48            handler: "RestAPI".to_string(),
49            host: self.client.host().to_string(),
50            port: self.client.port(),
51            user: self.client.username(),
52            database: self.client.current_database(),
53            warehouse: self.client.current_warehouse(),
54        }
55    }
56
57    fn last_query_id(&self) -> Option<String> {
58        self.client.last_query_id()
59    }
60
61    async fn close(&self) -> Result<()> {
62        self.client.close().await;
63        Ok(())
64    }
65
66    async fn exec(&self, sql: &str) -> Result<i64> {
67        info!("exec: {}", sql);
68        let mut resp = self.client.start_query(sql).await?;
69        let node_id = resp.node_id.clone();
70        while let Some(next_uri) = resp.next_uri {
71            resp = self
72                .client
73                .query_page(&resp.id, &next_uri, &node_id)
74                .await?;
75        }
76        Ok(resp.stats.progresses.write_progress.rows as i64)
77    }
78
79    async fn kill_query(&self, query_id: &str) -> Result<()> {
80        Ok(self.client.kill_query(query_id).await?)
81    }
82
83    async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
84        info!("query iter: {}", sql);
85        let rows_with_progress = self.query_iter_ext(sql).await?;
86        let rows = rows_with_progress.filter_rows().await;
87        Ok(rows)
88    }
89
90    async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
91        info!("query iter ext: {}", sql);
92        let resp = self.client.start_query(sql).await?;
93        let resp = self.wait_for_schema(resp, true).await?;
94        let (schema, rows) = RestAPIRows::<RowWithStats>::from_response(self.client.clone(), resp)?;
95        Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
96    }
97
98    // raw data response query, only for test
99    async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
100        info!("query raw iter: {}", sql);
101        let resp = self.client.start_query(sql).await?;
102        let resp = self.wait_for_schema(resp, true).await?;
103        let (schema, rows) =
104            RestAPIRows::<RawRowWithStats>::from_response(self.client.clone(), resp)?;
105        Ok(RawRowIterator::new(Arc::new(schema), Box::pin(rows)))
106    }
107
108    async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse> {
109        info!("get presigned url: {} {}", operation, stage);
110        let sql = format!("PRESIGN {} {}", operation, stage);
111        let row = self.query_row(&sql).await?.ok_or_else(|| {
112            Error::InvalidResponse("Empty response from server for presigned request".to_string())
113        })?;
114        let (method, headers, url): (String, String, String) =
115            row.try_into().map_err(Error::Parsing)?;
116        let headers: BTreeMap<String, String> = serde_json::from_str(&headers)?;
117        Ok(PresignedResponse {
118            method,
119            headers,
120            url,
121        })
122    }
123
124    async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
125        self.client.upload_to_stage(stage, data, size).await?;
126        Ok(())
127    }
128
129    async fn load_data(
130        &self,
131        sql: &str,
132        data: Reader,
133        size: u64,
134        file_format_options: Option<BTreeMap<&str, &str>>,
135        copy_options: Option<BTreeMap<&str, &str>>,
136    ) -> Result<ServerStats> {
137        info!(
138            "load data: {}, size: {}, format: {:?}, copy: {:?}",
139            sql, size, file_format_options, copy_options
140        );
141        let now = chrono::Utc::now()
142            .timestamp_nanos_opt()
143            .ok_or_else(|| Error::IO("Failed to get current timestamp".to_string()))?;
144        let stage = format!("@~/client/load/{}", now);
145
146        let file_format_options =
147            file_format_options.unwrap_or_else(Self::default_file_format_options);
148        let copy_options = copy_options.unwrap_or_else(Self::default_copy_options);
149
150        self.upload_to_stage(&stage, data, size).await?;
151        let resp = self
152            .client
153            .insert_with_stage(sql, &stage, file_format_options, copy_options)
154            .await?;
155        Ok(ServerStats::from(resp.stats))
156    }
157
158    async fn load_file(
159        &self,
160        sql: &str,
161        fp: &Path,
162        format_options: Option<BTreeMap<&str, &str>>,
163        copy_options: Option<BTreeMap<&str, &str>>,
164    ) -> Result<ServerStats> {
165        info!(
166            "load file: {}, file: {:?}, format: {:?}, copy: {:?}",
167            sql, fp, format_options, copy_options
168        );
169        let file = File::open(fp).await?;
170        let metadata = file.metadata().await?;
171        let size = metadata.len();
172        let data = BufReader::new(file);
173        let mut format_options = format_options.unwrap_or_else(Self::default_file_format_options);
174        if !format_options.contains_key("type") {
175            let file_type = fp
176                .extension()
177                .ok_or_else(|| Error::BadArgument("file type not specified".to_string()))?
178                .to_str()
179                .ok_or_else(|| Error::BadArgument("file type empty".to_string()))?;
180            format_options.insert("type", file_type);
181        }
182        self.load_data(
183            sql,
184            Box::new(data),
185            size,
186            Some(format_options),
187            copy_options,
188        )
189        .await
190    }
191
192    async fn stream_load(&self, sql: &str, data: Vec<Vec<&str>>) -> Result<ServerStats> {
193        info!("stream load: {}, length: {:?}", sql, data.len());
194        let mut wtr = csv::WriterBuilder::new().from_writer(vec![]);
195        for row in data {
196            wtr.write_record(row)
197                .map_err(|e| Error::BadArgument(e.to_string()))?;
198        }
199        let bytes = wtr.into_inner().map_err(|e| Error::IO(e.to_string()))?;
200        let size = bytes.len() as u64;
201        let reader = Box::new(std::io::Cursor::new(bytes));
202        let stats = self.load_data(sql, reader, size, None, None).await?;
203        Ok(stats)
204    }
205}
206
207impl<'o> RestAPIConnection {
208    pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
209        let client = APIClient::new(dsn, Some(name)).await?;
210        Ok(Self {
211            client: Arc::new(client),
212        })
213    }
214
215    async fn wait_for_schema(
216        &self,
217        resp: QueryResponse,
218        return_on_progress: bool,
219    ) -> Result<QueryResponse> {
220        if !resp.data.is_empty()
221            || !resp.schema.is_empty()
222            || (return_on_progress && resp.stats.progresses.has_progress())
223        {
224            return Ok(resp);
225        }
226        let node_id = resp.node_id.clone();
227        if let Some(node_id) = &node_id {
228            self.client.set_last_node_id(node_id.clone());
229        }
230        let mut result = resp;
231        // preserve schema since it is not included in the final response
232        while let Some(next_uri) = result.next_uri {
233            result = self
234                .client
235                .query_page(&result.id, &next_uri, &node_id)
236                .await?;
237
238            if !result.data.is_empty()
239                || !result.schema.is_empty()
240                || (return_on_progress && result.stats.progresses.has_progress())
241            {
242                break;
243            }
244        }
245        Ok(result)
246    }
247
248    fn default_file_format_options() -> BTreeMap<&'o str, &'o str> {
249        vec![
250            ("type", "CSV"),
251            ("field_delimiter", ","),
252            ("record_delimiter", "\n"),
253            ("skip_header", "0"),
254        ]
255        .into_iter()
256        .collect()
257    }
258
259    fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
260        vec![("purge", "true")].into_iter().collect()
261    }
262
263    pub async fn query_row_batch(&self, sql: &str) -> Result<RowBatch> {
264        let resp = self.client.start_query(sql).await?;
265        let resp = self.wait_for_schema(resp, false).await?;
266        RowBatch::from_response(self.client.clone(), resp)
267    }
268}
269
270type PageFut = Pin<Box<dyn Future<Output = Result<QueryResponse>> + Send>>;
271
272pub struct RestAPIRows<T> {
273    client: Arc<APIClient>,
274    schema: SchemaRef,
275    data: VecDeque<Vec<Option<String>>>,
276    stats: Option<ServerStats>,
277    query_id: String,
278    node_id: Option<String>,
279    next_uri: Option<String>,
280    next_page: Option<PageFut>,
281    _phantom: std::marker::PhantomData<T>,
282}
283
284impl<T> RestAPIRows<T> {
285    fn from_response(client: Arc<APIClient>, resp: QueryResponse) -> Result<(Schema, Self)> {
286        let schema: Schema = resp.schema.try_into()?;
287        let rows = Self {
288            client,
289            query_id: resp.id,
290            node_id: resp.node_id,
291            next_uri: resp.next_uri,
292            schema: Arc::new(schema.clone()),
293            data: resp.data.into(),
294            stats: Some(ServerStats::from(resp.stats)),
295            next_page: None,
296            _phantom: PhantomData,
297        };
298        Ok((schema, rows))
299    }
300}
301
302impl<T: FromRowStats + std::marker::Unpin> Stream for RestAPIRows<T> {
303    type Item = Result<T>;
304
305    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
306        if let Some(ss) = self.stats.take() {
307            return Poll::Ready(Some(Ok(T::from_stats(ss))));
308        }
309        // Skip to fetch next page if there is only one row left in buffer.
310        // Therefore we could guarantee the `/final` called before the last row.
311        if self.data.len() > 1 {
312            if let Some(row) = self.data.pop_front() {
313                let row = T::try_from_row(row, self.schema.clone())?;
314                return Poll::Ready(Some(Ok(row)));
315            }
316        }
317        match self.next_page {
318            Some(ref mut next_page) => match Pin::new(next_page).poll(cx) {
319                Poll::Ready(Ok(resp)) => {
320                    if self.schema.fields().is_empty() {
321                        self.schema = Arc::new(resp.schema.try_into()?);
322                    }
323                    self.next_uri = resp.next_uri;
324                    self.next_page = None;
325                    let mut new_data = resp.data.into();
326                    self.data.append(&mut new_data);
327                    Poll::Ready(Some(Ok(T::from_stats(resp.stats.into()))))
328                }
329                Poll::Ready(Err(e)) => {
330                    self.next_page = None;
331                    Poll::Ready(Some(Err(e)))
332                }
333                Poll::Pending => Poll::Pending,
334            },
335            None => match self.next_uri {
336                Some(ref next_uri) => {
337                    let client = self.client.clone();
338                    let next_uri = next_uri.clone();
339                    let query_id = self.query_id.clone();
340                    let node_id = self.node_id.clone();
341                    self.next_page = Some(Box::pin(async move {
342                        client
343                            .query_page(&query_id, &next_uri, &node_id)
344                            .await
345                            .map_err(|e| e.into())
346                    }));
347                    self.poll_next(cx)
348                }
349                None => match self.data.pop_front() {
350                    Some(row) => {
351                        let row = T::try_from_row(row, self.schema.clone())?;
352                        Poll::Ready(Some(Ok(row)))
353                    }
354                    None => Poll::Ready(None),
355                },
356            },
357        }
358    }
359}
360
361trait FromRowStats: Send + Sync + Clone {
362    fn from_stats(stats: ServerStats) -> Self;
363    fn try_from_row(row: Vec<Option<String>>, schema: SchemaRef) -> Result<Self>;
364}
365
366impl FromRowStats for RowWithStats {
367    fn from_stats(stats: ServerStats) -> Self {
368        RowWithStats::Stats(stats)
369    }
370
371    fn try_from_row(row: Vec<Option<String>>, schema: SchemaRef) -> Result<Self> {
372        Ok(RowWithStats::Row(Row::try_from((schema, row))?))
373    }
374}
375
376impl FromRowStats for RawRowWithStats {
377    fn from_stats(stats: ServerStats) -> Self {
378        RawRowWithStats::Stats(stats)
379    }
380
381    fn try_from_row(row: Vec<Option<String>>, schema: SchemaRef) -> Result<Self> {
382        let rows = Row::try_from((schema, row.clone()))?;
383        Ok(RawRowWithStats::Row(RawRow::new(rows, row)))
384    }
385}
386
387pub struct RowBatch {
388    schema: Vec<SchemaField>,
389    client: Arc<APIClient>,
390    query_id: String,
391    node_id: Option<String>,
392
393    next_uri: Option<String>,
394    data: Vec<Vec<Option<String>>>,
395}
396
397impl RowBatch {
398    pub fn schema(&self) -> Vec<SchemaField> {
399        self.schema.clone()
400    }
401
402    fn from_response(client: Arc<APIClient>, mut resp: QueryResponse) -> Result<Self> {
403        Ok(Self {
404            schema: std::mem::take(&mut resp.schema),
405            client,
406            query_id: resp.id,
407            node_id: resp.node_id,
408            next_uri: resp.next_uri,
409            data: resp.data,
410        })
411    }
412
413    pub async fn fetch_next_page(&mut self) -> Result<Vec<Vec<Option<String>>>> {
414        if !self.data.is_empty() {
415            return Ok(std::mem::take(&mut self.data));
416        }
417        while let Some(next_uri) = &self.next_uri {
418            let resp = self
419                .client
420                .query_page(&self.query_id, next_uri, &self.node_id)
421                .await?;
422
423            self.next_uri = resp.next_uri;
424            if !resp.data.is_empty() {
425                return Ok(resp.data);
426            }
427        }
428        Ok(vec![])
429    }
430}