Skip to main content

databend_driver/
rest_api.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use async_trait::async_trait;
16use log::info;
17use std::collections::{BTreeMap, VecDeque};
18use std::marker::PhantomData;
19use std::path::Path;
20use std::pin::Pin;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23use std::time::Instant;
24use tokio::fs::File;
25use tokio::io::BufReader;
26use tokio_stream::Stream;
27
28use crate::client::LoadMethod;
29use crate::conn::{ConnectionInfo, IConnection, Reader};
30use databend_client::schema::{Schema, SchemaRef};
31use databend_client::Pages;
32use databend_client::{APIClient, ResultFormatSettings};
33use databend_driver_core::error::{Error, Result};
34use databend_driver_core::raw_rows::{RawRow, RawRowIterator, RawRowWithStats};
35use databend_driver_core::rows::{
36    Row, RowIterator, RowStatsIterator, RowWithStats, Rows, ServerStats,
37};
38
/// Stage placeholder recognised in load statements.
///
/// `load_data` searches the lowercased SQL text for this value, and
/// `load_data_with_stage` substitutes it with the temporary stage location
/// the payload was uploaded to.
39const LOAD_PLACEHOLDER: &str = "@_databend_load";
40
/// Connection that implements `IConnection` by delegating to a shared
/// `APIClient`.
///
/// Cloning only clones the `Arc`, so clones share the same client.
41#[derive(Clone)]
42pub struct RestAPIConnection {
    // Shared client; every method of this connection forwards to it.
43    client: Arc<APIClient>,
44}
45
46impl RestAPIConnection {
47    fn gen_temp_stage_location(&self) -> Result<String> {
48        let now = chrono::Utc::now()
49            .timestamp_nanos_opt()
50            .ok_or_else(|| Error::IO("Failed to get current timestamp".to_string()))?;
51        Ok(format!("@~/client/load/{now}"))
52    }
53
54    async fn load_data_with_stage(
55        &self,
56        sql: &str,
57        data: Reader,
58        size: u64,
59    ) -> Result<ServerStats> {
60        let location = self.gen_temp_stage_location()?;
61        self.upload_to_stage(&location, data, size).await?;
62        if self.client.capability().streaming_load {
63            let sql = sql.replace(LOAD_PLACEHOLDER, &location);
64            let page = self.client.query_all(&sql, None).await?;
65            Ok(ServerStats::from(page.stats))
66        } else {
67            let file_format_options = Self::default_file_format_options();
68            let copy_options = Self::default_copy_options();
69            let stats = self
70                .client
71                .insert_with_stage(sql, &location, file_format_options, copy_options)
72                .await?;
73            Ok(ServerStats::from(stats))
74        }
75    }
76
77    async fn load_data_with_streaming(
78        &self,
79        sql: &str,
80        data: Reader,
81        size: u64,
82    ) -> Result<ServerStats> {
83        let start = Instant::now();
84        let response = self
85            .client
86            .streaming_load(sql, data, "<no_filename>")
87            .await?;
88        Ok(ServerStats {
89            total_rows: 0,
90            total_bytes: 0,
91            read_rows: response.stats.rows,
92            read_bytes: size as usize,
93            write_rows: response.stats.rows,
94            write_bytes: response.stats.bytes,
95            running_time_ms: start.elapsed().as_millis() as f64,
96            spill_file_nums: 0,
97            spill_bytes: 0,
98        })
99    }
100    async fn load_data_with_options(
101        &self,
102        sql: &str,
103        data: Reader,
104        size: u64,
105        file_format_options: Option<BTreeMap<&str, &str>>,
106        copy_options: Option<BTreeMap<&str, &str>>,
107    ) -> Result<ServerStats> {
108        let location = self.gen_temp_stage_location()?;
109        let file_format_options =
110            file_format_options.unwrap_or_else(Self::default_file_format_options);
111        let copy_options = copy_options.unwrap_or_else(Self::default_copy_options);
112        self.upload_to_stage(&location, Box::new(data), size)
113            .await?;
114        let stats = self
115            .client
116            .insert_with_stage(sql, &location, file_format_options, copy_options)
117            .await?;
118        Ok(ServerStats::from(stats))
119    }
120}
121
122#[async_trait]
123impl IConnection for RestAPIConnection {
124    async fn info(&self) -> ConnectionInfo {
125        ConnectionInfo {
126            handler: "RestAPI".to_string(),
127            host: self.client.host().to_string(),
128            port: self.client.port(),
129            user: self.client.username(),
130            catalog: self.client.current_catalog(),
131            database: self.client.current_database(),
132            warehouse: self.client.current_warehouse(),
133        }
134    }
135
136    fn last_query_id(&self) -> Option<String> {
137        self.client.last_query_id()
138    }
139
140    fn set_warehouse(&self, warehouse: &str) -> Result<()> {
141        self.client.set_warehouse(warehouse.to_string());
142        Ok(())
143    }
144
145    fn set_database(&self, database: &str) -> Result<()> {
146        self.client.set_database(database.to_string());
147        Ok(())
148    }
149
150    fn set_role(&self, role: &str) -> Result<()> {
151        self.client.set_role(role.to_string());
152        Ok(())
153    }
154
155    fn set_session(&self, key: &str, value: &str) -> Result<()> {
156        self.client.set_session(key.to_string(), value.to_string());
157        Ok(())
158    }
159
160    async fn close(&self) -> Result<()> {
161        self.client.close().await;
162        Ok(())
163    }
164
165    fn close_with_spawn(&self) -> Result<()> {
166        self.client.close_with_spawn();
167        Ok(())
168    }
169
170    fn supports_server_side_params(&self) -> bool {
171        self.client.capability().server_side_params
172    }
173
174    async fn exec(&self, sql: &str) -> Result<i64> {
175        info!("exec: {}", sql);
176        let page = self.client.query_all(sql, None).await?;
177        Ok(page.stats.progresses.write_progress.rows as i64)
178    }
179
180    async fn exec_with_params(&self, sql: &str, params: Option<serde_json::Value>) -> Result<i64> {
181        info!("exec with params: {}", sql);
182        let page = self.client.query_all(sql, params).await?;
183        Ok(page.stats.progresses.write_progress.rows as i64)
184    }
185
186    async fn kill_query(&self, query_id: &str) -> Result<()> {
187        Ok(self.client.kill_query(query_id).await?)
188    }
189
190    async fn query_iter(&self, sql: &str) -> Result<RowIterator> {
191        info!("query iter: {}", sql);
192        let rows_with_progress = self.query_iter_ext(sql).await?;
193        let rows = rows_with_progress.filter_rows().await?;
194        Ok(rows)
195    }
196
197    async fn query_iter_ext(&self, sql: &str) -> Result<RowStatsIterator> {
198        info!("query iter ext: {}", sql);
199        let pages = self.client.start_query(sql, true, None).await?;
200        let (schema, rows) = RestAPIRows::<RowWithStats>::from_pages(pages).await?;
201        Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
202    }
203
204    async fn query_iter_with_params(
205        &self,
206        sql: &str,
207        params: Option<serde_json::Value>,
208    ) -> Result<RowIterator> {
209        info!("query iter with params: {}", sql);
210        let rows_with_progress = self.query_iter_ext_with_params(sql, params).await?;
211        let rows = rows_with_progress.filter_rows().await?;
212        Ok(rows)
213    }
214
215    async fn query_iter_ext_with_params(
216        &self,
217        sql: &str,
218        params: Option<serde_json::Value>,
219    ) -> Result<RowStatsIterator> {
220        info!("query iter ext with params: {}", sql);
221        let pages = self.client.start_query(sql, true, params).await?;
222        let (schema, rows) = RestAPIRows::<RowWithStats>::from_pages(pages).await?;
223        Ok(RowStatsIterator::new(Arc::new(schema), Box::pin(rows)))
224    }
225
226    // raw data response query, only for test
227    async fn query_raw_iter(&self, sql: &str) -> Result<RawRowIterator> {
228        info!("query raw iter: {}", sql);
229        let pages = self.client.start_query(sql, true, None).await?;
230        let (schema, rows) = RestAPIRows::<RawRowWithStats>::from_pages(pages).await?;
231        Ok(RawRowIterator::new(Arc::new(schema), Box::pin(rows)))
232    }
233
234    async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
235        self.client.upload_to_stage(stage, data, size).await?;
236        Ok(())
237    }
238
239    async fn load_data(
240        &self,
241        sql: &str,
242        data: Reader,
243        size: u64,
244        method: LoadMethod,
245    ) -> Result<ServerStats> {
246        let sql = sql.trim_end();
247        let sql = sql.trim_end_matches(';');
248        info!("load data: {}, size: {}, method: {method:?}", sql, size);
249        let sql_low = sql.to_lowercase();
250        let has_place_holder = sql_low.contains(LOAD_PLACEHOLDER);
251        let sql = match (self.client.capability().streaming_load, has_place_holder) {
252            (false, false) => {
253                // todo: deprecate this later
254                return self
255                    .load_data_with_options(sql, data, size, None, None)
256                    .await;
257            }
258            (false, true) => return Err(Error::BadArgument(
259                "Please upgrade your server to >= 1.2.781 to support insert from @_databend_load"
260                    .to_string(),
261            )),
262            (true, false) => {
263                format!("{sql} from @_databend_load file_format=(type=csv)")
264            }
265            (true, true) => sql.to_string(),
266        };
267
268        match method {
269            LoadMethod::Streaming => self.load_data_with_streaming(&sql, data, size).await,
270            LoadMethod::Stage => self.load_data_with_stage(&sql, data, size).await,
271        }
272    }
273
274    async fn load_file(&self, sql: &str, fp: &Path, method: LoadMethod) -> Result<ServerStats> {
275        info!("load file: {}, file: {:?}", sql, fp,);
276        let file = File::open(fp).await?;
277        let metadata = file.metadata().await?;
278        let size = metadata.len();
279        let data = BufReader::new(file);
280        self.load_data(sql, Box::new(data), size, method).await
281    }
282
283    async fn load_file_with_options(
284        &self,
285        sql: &str,
286        fp: &Path,
287        file_format_options: Option<BTreeMap<&str, &str>>,
288        copy_options: Option<BTreeMap<&str, &str>>,
289    ) -> Result<ServerStats> {
290        let file = File::open(fp).await?;
291        let metadata = file.metadata().await?;
292        let size = metadata.len();
293        let data = BufReader::new(file);
294        self.load_data_with_options(sql, Box::new(data), size, file_format_options, copy_options)
295            .await
296    }
297
298    async fn stream_load(
299        &self,
300        sql: &str,
301        data: Vec<Vec<&str>>,
302        method: LoadMethod,
303    ) -> Result<ServerStats> {
304        info!("stream load: {}; rows: {:?}", sql, data.len());
305        let mut wtr = csv::WriterBuilder::new().from_writer(vec![]);
306        for row in data {
307            wtr.write_record(row)
308                .map_err(|e| Error::BadArgument(e.to_string()))?;
309        }
310        let bytes = wtr.into_inner().map_err(|e| Error::IO(e.to_string()))?;
311        let size = bytes.len() as u64;
312        let reader = Box::new(std::io::Cursor::new(bytes));
313        let stats = if self.client.capability().streaming_load {
314            let sql = format!("{sql} from @_databend_load file_format = (type = csv)");
315            self.load_data(&sql, reader, size, method).await?
316        } else {
317            self.load_data_with_options(sql, reader, size, None, None)
318                .await?
319        };
320        Ok(stats)
321    }
322}
323
324impl<'o> RestAPIConnection {
325    pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
326        let client = APIClient::new(dsn, Some(name)).await?;
327        Ok(Self { client })
328    }
329
330    fn default_file_format_options() -> BTreeMap<&'o str, &'o str> {
331        vec![
332            ("type", "CSV"),
333            ("field_delimiter", ","),
334            ("record_delimiter", "\n"),
335            ("skip_header", "0"),
336        ]
337        .into_iter()
338        .collect()
339    }
340
341    fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
342        vec![("purge", "true")].into_iter().collect()
343    }
344}
345
346pub struct RestAPIRows<T> {
347    pages: Pages,
348
349    schema: SchemaRef,
350    settings: ResultFormatSettings,
351
352    data: VecDeque<Vec<Option<String>>>,
353    rows: VecDeque<Row>,
354
355    stats: Option<ServerStats>,
356
357    _phantom: std::marker::PhantomData<T>,
358}
359
360impl<T> RestAPIRows<T> {
361    async fn from_pages(pages: Pages) -> Result<(Schema, Self)> {
362        let (pages, schema, settings) = pages.wait_for_schema(true).await?;
363        let rows = Self {
364            pages,
365            schema: Arc::new(schema.clone()),
366            settings,
367            data: Default::default(),
368            rows: Default::default(),
369            stats: None,
370            _phantom: PhantomData,
371        };
372        Ok((schema, rows))
373    }
374}
375
376impl<T: FromRowStats + std::marker::Unpin> Stream for RestAPIRows<T> {
377    type Item = Result<T>;
378
379    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
380        if let Some(ss) = self.stats.take() {
381            return Poll::Ready(Some(Ok(T::from_stats(ss))));
382        }
383        // Skip to fetch next page if there is only one row left in buffer.
384        // Therefore, we could guarantee the `/final` called before the last row.
385        if self.data.len() > 1 {
386            if let Some(row) = self.data.pop_front() {
387                let row = T::try_from_raw_row(row, self.schema.clone(), &self.settings)?;
388                return Poll::Ready(Some(Ok(row)));
389            }
390        } else if self.rows.len() > 1 {
391            if let Some(row) = self.rows.pop_front() {
392                let row = T::from_row(row);
393                return Poll::Ready(Some(Ok(row)));
394            }
395        }
396
397        match Pin::new(&mut self.pages).poll_next(cx) {
398            Poll::Ready(Some(Ok(page))) => {
399                if self.schema.fields().is_empty() {
400                    if !page.raw_schema.is_empty() {
401                        self.schema = Arc::new(page.raw_schema.try_into()?);
402                    } else if !page.batches.is_empty() {
403                        self.schema = Arc::new(page.batches[0].schema().clone().try_into()?);
404                    }
405                }
406                if page.batches.is_empty() {
407                    let mut new_data = page.data.into();
408                    self.data.append(&mut new_data);
409                } else {
410                    for batch in page.batches.into_iter() {
411                        let rows = Rows::try_from((batch, self.settings.clone()))?;
412                        self.rows.extend(rows);
413                    }
414                }
415                Poll::Ready(Some(Ok(T::from_stats(page.stats.into()))))
416            }
417            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
418            Poll::Ready(None) => {
419                if let Some(row) = self.rows.pop_front() {
420                    let row = T::from_row(row);
421                    Poll::Ready(Some(Ok(row)))
422                } else if let Some(row) = self.data.pop_front() {
423                    let row = T::try_from_raw_row(row, self.schema.clone(), &self.settings)?;
424                    Poll::Ready(Some(Ok(row)))
425                } else {
426                    Poll::Ready(None)
427                }
428            }
429            Poll::Pending => Poll::Pending,
430        }
431    }
432}
433
434trait FromRowStats: Send + Sync + Clone {
435    fn from_stats(stats: ServerStats) -> Self;
436    fn try_from_raw_row(
437        row: Vec<Option<String>>,
438        schema: SchemaRef,
439        tz: &ResultFormatSettings,
440    ) -> Result<Self>;
441    fn from_row(row: Row) -> Self;
442}
443
444impl FromRowStats for RowWithStats {
445    fn from_stats(stats: ServerStats) -> Self {
446        RowWithStats::Stats(stats)
447    }
448
449    fn try_from_raw_row(
450        row: Vec<Option<String>>,
451        schema: SchemaRef,
452        settings: &ResultFormatSettings,
453    ) -> Result<Self> {
454        Ok(RowWithStats::Row(Row::try_from((schema, row, settings))?))
455    }
456    fn from_row(row: Row) -> Self {
457        RowWithStats::Row(row)
458    }
459}
460
461impl FromRowStats for RawRowWithStats {
462    fn from_stats(stats: ServerStats) -> Self {
463        RawRowWithStats::Stats(stats)
464    }
465
466    fn try_from_raw_row(
467        row: Vec<Option<String>>,
468        schema: SchemaRef,
469        tz: &ResultFormatSettings,
470    ) -> Result<Self> {
471        let rows = Row::try_from((schema, row.clone(), tz))?;
472        Ok(RawRowWithStats::Row(RawRow::new(rows, row)))
473    }
474
475    fn from_row(row: Row) -> Self {
476        RawRowWithStats::Row(RawRow::from(row))
477    }
478}