databend_driver/
rest_api.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, VecDeque};
16use std::future::Future;
17use std::marker::PhantomData;
18use std::path::Path;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22
23use async_trait::async_trait;
24use log::info;
25use tokio::fs::File;
26use tokio::io::BufReader;
27use tokio_stream::Stream;
28
29use databend_client::PresignedResponse;
30use databend_client::QueryResponse;
31use databend_client::{APIClient, SchemaField};
32use databend_driver_core::error::{Error, Result};
33use databend_driver_core::raw_rows::{RawRow, RawRowIterator, RawRowWithStats};
34use databend_driver_core::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, ServerStats};
35use databend_driver_core::schema::{Schema, SchemaRef};
36
37use crate::conn::{ConnectionInfo, IConnection, Reader};
38
39#[derive(Clone)]
40pub struct RestAPIConnection {
41    client: Arc<APIClient>,
42}
43
44#[async_trait]
45impl IConnection for RestAPIConnection {
46    async fn info(&self) -> ConnectionInfo {
47        ConnectionInfo {
48            handler: "RestAPI".to_string(),
49            host: self.client.host().to_string(),
50            port: self.client.port(),
51            user: self.client.username(),
52            database: self.client.current_database(),
53            warehouse: self.client.current_warehouse(),
54        }
55    }
56
57    fn last_query_id(&self) -> Option<String> {
58        self.client.last_query_id()
59    }
60
61    async fn close(&self) -> Result<()> {
62        self.client.close().await;
63        Ok(())
64    }
65
66    async fn exec(&self, sql: &str) -> Result<i64> {
67        info!("exec: {}", sql);
68        let mut resp = self.client.start_query(sql).await?;
69        let node_id = resp.node_id.clone();
70        if let Some(node_id) = &node_id {
71            self.client.set_last_node_id(node_id.clone());
72        }
73        while let Some(next_uri) = resp.next_uri {
74            resp = self
75                .client
76                .query_page(&resp.id, &next_uri, &node_id)
77                .await?;
78        }
79        Ok(resp.stats.progresses.write_progress.rows as i64)
80    }
81
82    async fn kill_query(&self, query_id: &str) -> Result<()> {
83        Ok(self.client.kill_query(query_id).await?)
84    }
85
86    async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
87        info!("query iter: {}", sql);
88        let rows_with_progress = self.query_iter_ext(sql).await?;
89        let rows = rows_with_progress.filter_rows().await;
90        Ok(rows)
91    }
92
93    async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
94        info!("query iter ext: {}", sql);
95        let resp = self.client.start_query(sql).await?;
96        let resp = self.wait_for_schema(resp, true).await?;
97        let (schema, rows) = RestAPIRows::<RowWithStats>::from_response(self.client.clone(), resp)?;
98        Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
99    }
100
101    // raw data response query, only for test
102    async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
103        info!("query raw iter: {}", sql);
104        let resp = self.client.start_query(sql).await?;
105        let resp = self.wait_for_schema(resp, true).await?;
106        let (schema, rows) =
107            RestAPIRows::<RawRowWithStats>::from_response(self.client.clone(), resp)?;
108        Ok(RawRowIterator::new(Arc::new(schema), Box::pin(rows)))
109    }
110
111    async fn get_presigned_url(&self, operation: &str, stage: &str) -> Result<PresignedResponse> {
112        info!("get presigned url: {} {}", operation, stage);
113        let sql = format!("PRESIGN {} {}", operation, stage);
114        let row = self.query_row(&sql).await?.ok_or_else(|| {
115            Error::InvalidResponse("Empty response from server for presigned request".to_string())
116        })?;
117        let (method, headers, url): (String, String, String) =
118            row.try_into().map_err(Error::Parsing)?;
119        let headers: BTreeMap<String, String> = serde_json::from_str(&headers)?;
120        Ok(PresignedResponse {
121            method,
122            headers,
123            url,
124        })
125    }
126
127    async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
128        self.client.upload_to_stage(stage, data, size).await?;
129        Ok(())
130    }
131
132    async fn load_data(
133        &self,
134        sql: &str,
135        data: Reader,
136        size: u64,
137        file_format_options: Option<BTreeMap<&str, &str>>,
138        copy_options: Option<BTreeMap<&str, &str>>,
139    ) -> Result<ServerStats> {
140        info!(
141            "load data: {}, size: {}, format: {:?}, copy: {:?}",
142            sql, size, file_format_options, copy_options
143        );
144        let now = chrono::Utc::now()
145            .timestamp_nanos_opt()
146            .ok_or_else(|| Error::IO("Failed to get current timestamp".to_string()))?;
147        let stage = format!("@~/client/load/{}", now);
148
149        let file_format_options =
150            file_format_options.unwrap_or_else(Self::default_file_format_options);
151        let copy_options = copy_options.unwrap_or_else(Self::default_copy_options);
152
153        self.upload_to_stage(&stage, data, size).await?;
154        let resp = self
155            .client
156            .insert_with_stage(sql, &stage, file_format_options, copy_options)
157            .await?;
158        Ok(ServerStats::from(resp.stats))
159    }
160
161    async fn load_file(
162        &self,
163        sql: &str,
164        fp: &Path,
165        format_options: Option<BTreeMap<&str, &str>>,
166        copy_options: Option<BTreeMap<&str, &str>>,
167    ) -> Result<ServerStats> {
168        info!(
169            "load file: {}, file: {:?}, format: {:?}, copy: {:?}",
170            sql, fp, format_options, copy_options
171        );
172        let file = File::open(fp).await?;
173        let metadata = file.metadata().await?;
174        let size = metadata.len();
175        let data = BufReader::new(file);
176        let mut format_options = format_options.unwrap_or_else(Self::default_file_format_options);
177        if !format_options.contains_key("type") {
178            let file_type = fp
179                .extension()
180                .ok_or_else(|| Error::BadArgument("file type not specified".to_string()))?
181                .to_str()
182                .ok_or_else(|| Error::BadArgument("file type empty".to_string()))?;
183            format_options.insert("type", file_type);
184        }
185        self.load_data(
186            sql,
187            Box::new(data),
188            size,
189            Some(format_options),
190            copy_options,
191        )
192        .await
193    }
194
195    async fn stream_load(&self, sql: &str, data: Vec<Vec<&str>>) -> Result<ServerStats> {
196        info!("stream load: {}, length: {:?}", sql, data.len());
197        let mut wtr = csv::WriterBuilder::new().from_writer(vec![]);
198        for row in data {
199            wtr.write_record(row)
200                .map_err(|e| Error::BadArgument(e.to_string()))?;
201        }
202        let bytes = wtr.into_inner().map_err(|e| Error::IO(e.to_string()))?;
203        let size = bytes.len() as u64;
204        let reader = Box::new(std::io::Cursor::new(bytes));
205        let stats = self.load_data(sql, reader, size, None, None).await?;
206        Ok(stats)
207    }
208}
209
210impl<'o> RestAPIConnection {
211    pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
212        let client = APIClient::new(dsn, Some(name)).await?;
213        Ok(Self {
214            client: Arc::new(client),
215        })
216    }
217
218    async fn wait_for_schema(
219        &self,
220        resp: QueryResponse,
221        return_on_progress: bool,
222    ) -> Result<QueryResponse> {
223        let node_id = resp.node_id.clone();
224        if let Some(node_id) = &node_id {
225            self.client.set_last_node_id(node_id.clone());
226        }
227        if !resp.data.is_empty()
228            || !resp.schema.is_empty()
229            || (return_on_progress && resp.stats.progresses.has_progress())
230        {
231            return Ok(resp);
232        }
233        let mut result = resp;
234        // preserve schema since it is not included in the final response
235        while let Some(next_uri) = result.next_uri {
236            result = self
237                .client
238                .query_page(&result.id, &next_uri, &node_id)
239                .await?;
240
241            if !result.data.is_empty()
242                || !result.schema.is_empty()
243                || (return_on_progress && result.stats.progresses.has_progress())
244            {
245                break;
246            }
247        }
248        Ok(result)
249    }
250
251    fn default_file_format_options() -> BTreeMap<&'o str, &'o str> {
252        vec![
253            ("type", "CSV"),
254            ("field_delimiter", ","),
255            ("record_delimiter", "\n"),
256            ("skip_header", "0"),
257        ]
258        .into_iter()
259        .collect()
260    }
261
262    fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
263        vec![("purge", "true")].into_iter().collect()
264    }
265
266    pub async fn query_row_batch(&self, sql: &str) -> Result<RowBatch> {
267        let resp = self.client.start_query(sql).await?;
268        let resp = self.wait_for_schema(resp, false).await?;
269        RowBatch::from_response(self.client.clone(), resp)
270    }
271}
272
273type PageFut = Pin<Box<dyn Future<Output = Result<QueryResponse>> + Send>>;
274
275pub struct RestAPIRows<T> {
276    client: Arc<APIClient>,
277    schema: SchemaRef,
278    data: VecDeque<Vec<Option<String>>>,
279    stats: Option<ServerStats>,
280    query_id: String,
281    node_id: Option<String>,
282    next_uri: Option<String>,
283    next_page: Option<PageFut>,
284    _phantom: std::marker::PhantomData<T>,
285}
286
287impl<T> RestAPIRows<T> {
288    fn from_response(client: Arc<APIClient>, resp: QueryResponse) -> Result<(Schema, Self)> {
289        let schema: Schema = resp.schema.try_into()?;
290        let rows = Self {
291            client,
292            query_id: resp.id,
293            node_id: resp.node_id,
294            next_uri: resp.next_uri,
295            schema: Arc::new(schema.clone()),
296            data: resp.data.into(),
297            stats: Some(ServerStats::from(resp.stats)),
298            next_page: None,
299            _phantom: PhantomData,
300        };
301        Ok((schema, rows))
302    }
303}
304
305impl<T: FromRowStats + std::marker::Unpin> Stream for RestAPIRows<T> {
306    type Item = Result<T>;
307
308    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
309        if let Some(ss) = self.stats.take() {
310            return Poll::Ready(Some(Ok(T::from_stats(ss))));
311        }
312        // Skip to fetch next page if there is only one row left in buffer.
313        // Therefore we could guarantee the `/final` called before the last row.
314        if self.data.len() > 1 {
315            if let Some(row) = self.data.pop_front() {
316                let row = T::try_from_row(row, self.schema.clone())?;
317                return Poll::Ready(Some(Ok(row)));
318            }
319        }
320        match self.next_page {
321            Some(ref mut next_page) => match Pin::new(next_page).poll(cx) {
322                Poll::Ready(Ok(resp)) => {
323                    if self.schema.fields().is_empty() {
324                        self.schema = Arc::new(resp.schema.try_into()?);
325                    }
326                    self.next_uri = resp.next_uri;
327                    self.next_page = None;
328                    let mut new_data = resp.data.into();
329                    self.data.append(&mut new_data);
330                    Poll::Ready(Some(Ok(T::from_stats(resp.stats.into()))))
331                }
332                Poll::Ready(Err(e)) => {
333                    self.next_page = None;
334                    Poll::Ready(Some(Err(e)))
335                }
336                Poll::Pending => Poll::Pending,
337            },
338            None => match self.next_uri {
339                Some(ref next_uri) => {
340                    let client = self.client.clone();
341                    let next_uri = next_uri.clone();
342                    let query_id = self.query_id.clone();
343                    let node_id = self.node_id.clone();
344                    self.next_page = Some(Box::pin(async move {
345                        client
346                            .query_page(&query_id, &next_uri, &node_id)
347                            .await
348                            .map_err(|e| e.into())
349                    }));
350                    self.poll_next(cx)
351                }
352                None => match self.data.pop_front() {
353                    Some(row) => {
354                        let row = T::try_from_row(row, self.schema.clone())?;
355                        Poll::Ready(Some(Ok(row)))
356                    }
357                    None => Poll::Ready(None),
358                },
359            },
360        }
361    }
362}
363
364trait FromRowStats: Send + Sync + Clone {
365    fn from_stats(stats: ServerStats) -> Self;
366    fn try_from_row(row: Vec<Option<String>>, schema: SchemaRef) -> Result<Self>;
367}
368
369impl FromRowStats for RowWithStats {
370    fn from_stats(stats: ServerStats) -> Self {
371        RowWithStats::Stats(stats)
372    }
373
374    fn try_from_row(row: Vec<Option<String>>, schema: SchemaRef) -> Result<Self> {
375        Ok(RowWithStats::Row(Row::try_from((schema, row))?))
376    }
377}
378
379impl FromRowStats for RawRowWithStats {
380    fn from_stats(stats: ServerStats) -> Self {
381        RawRowWithStats::Stats(stats)
382    }
383
384    fn try_from_row(row: Vec<Option<String>>, schema: SchemaRef) -> Result<Self> {
385        let rows = Row::try_from((schema, row.clone()))?;
386        Ok(RawRowWithStats::Row(RawRow::new(rows, row)))
387    }
388}
389
390pub struct RowBatch {
391    schema: Vec<SchemaField>,
392    client: Arc<APIClient>,
393    query_id: String,
394    node_id: Option<String>,
395
396    next_uri: Option<String>,
397    data: Vec<Vec<Option<String>>>,
398}
399
400impl RowBatch {
401    pub fn schema(&self) -> Vec<SchemaField> {
402        self.schema.clone()
403    }
404
405    fn from_response(client: Arc<APIClient>, mut resp: QueryResponse) -> Result<Self> {
406        Ok(Self {
407            schema: std::mem::take(&mut resp.schema),
408            client,
409            query_id: resp.id,
410            node_id: resp.node_id,
411            next_uri: resp.next_uri,
412            data: resp.data,
413        })
414    }
415
416    pub async fn fetch_next_page(&mut self) -> Result<Vec<Vec<Option<String>>>> {
417        if !self.data.is_empty() {
418            return Ok(std::mem::take(&mut self.data));
419        }
420        while let Some(next_uri) = &self.next_uri {
421            let resp = self
422                .client
423                .query_page(&self.query_id, next_uri, &self.node_id)
424                .await?;
425
426            self.next_uri = resp.next_uri;
427            if !resp.data.is_empty() {
428                return Ok(resp.data);
429            }
430        }
431        Ok(vec![])
432    }
433}